From 7fd2f9089ab518fdc4e57b7d5e6b3bf8c130d97a Mon Sep 17 00:00:00 2001 From: Sri Harsha CH <57220027+harshachinta@users.noreply.github.com> Date: Mon, 27 Mar 2023 16:27:39 +0530 Subject: [PATCH 01/20] feat: Proto Columns Feature (#909) * feat: adding proto autogenerated code changes for proto column feature * feat: add implementation for Proto columns DDL * feat: add implementation for Proto columns DML * feat: add implementation for Proto columns DQL * feat: add NoneType check during Proto deserialization * feat: add code changes for Proto DDL support * feat: add required proto files to execute samples and tests * feat: add sample snippets for Proto columns DDL * feat: add tests for proto columns ddl, dml, dql snippets * feat: code refactoring * feat: remove staging endpoint from snippets.py * feat: comment refactor * feat: add license file * feat: update proto column data in insertion sample * feat: move column_info argument to the end to avoid breaking code --- .../types/spanner_database_admin.py | 25 ++ google/cloud/spanner_v1/_helpers.py | 43 ++- google/cloud/spanner_v1/data_types.py | 110 +++++++ google/cloud/spanner_v1/database.py | 20 +- google/cloud/spanner_v1/instance.py | 6 + google/cloud/spanner_v1/param_types.py | 34 +++ google/cloud/spanner_v1/session.py | 14 +- google/cloud/spanner_v1/snapshot.py | 28 +- google/cloud/spanner_v1/streamed.py | 10 +- google/cloud/spanner_v1/types/type.py | 11 + samples/samples/conftest.py | 9 +- samples/samples/snippets.py | 269 +++++++++++++++++- samples/samples/snippets_test.py | 72 ++++- samples/samples/testdata/descriptors.pb | Bin 0 -> 251 bytes samples/samples/testdata/singer.proto | 17 ++ samples/samples/testdata/singer_pb2.py | 41 +++ ...ixup_spanner_admin_database_v1_keywords.py | 4 +- .../test_database_admin.py | 4 + 18 files changed, 694 insertions(+), 23 deletions(-) create mode 100644 samples/samples/testdata/descriptors.pb create mode 100644 samples/samples/testdata/singer.proto create mode 100644 
samples/samples/testdata/singer_pb2.py diff --git a/google/cloud/spanner_admin_database_v1/types/spanner_database_admin.py b/google/cloud/spanner_admin_database_v1/types/spanner_database_admin.py index b105e1f04d..163e49416e 100644 --- a/google/cloud/spanner_admin_database_v1/types/spanner_database_admin.py +++ b/google/cloud/spanner_admin_database_v1/types/spanner_database_admin.py @@ -331,6 +331,11 @@ class CreateDatabaseRequest(proto.Message): database_dialect (google.cloud.spanner_admin_database_v1.types.DatabaseDialect): Optional. The dialect of the Cloud Spanner Database. + proto_descriptors (bytes): + Proto descriptors used by CREATE/ALTER PROTO BUNDLE + statements in 'extra_statements' above. Contains a + protobuf-serialized + `google.protobuf.FileDescriptorSet `__. """ parent: str = proto.Field( @@ -355,6 +360,10 @@ class CreateDatabaseRequest(proto.Message): number=5, enum=common.DatabaseDialect, ) + proto_descriptors: bytes = proto.Field( + proto.BYTES, + number=6, + ) class CreateDatabaseMetadata(proto.Message): @@ -435,6 +444,10 @@ class UpdateDatabaseDdlRequest(proto.Message): underscore. If the named operation already exists, [UpdateDatabaseDdl][google.spanner.admin.database.v1.DatabaseAdmin.UpdateDatabaseDdl] returns ``ALREADY_EXISTS``. + proto_descriptors (bytes): + Proto descriptors used by CREATE/ALTER PROTO BUNDLE + statements. Contains a protobuf-serialized + `google.protobuf.FileDescriptorSet `__. """ database: str = proto.Field( @@ -449,6 +462,10 @@ class UpdateDatabaseDdlRequest(proto.Message): proto.STRING, number=3, ) + proto_descriptors: bytes = proto.Field( + proto.BYTES, + number=4, + ) class UpdateDatabaseDdlMetadata(proto.Message): @@ -549,12 +566,20 @@ class GetDatabaseDdlResponse(proto.Message): A list of formatted DDL statements defining the schema of the database specified in the request. + proto_descriptors (bytes): + Proto descriptors stored in the database. 
Contains a + protobuf-serialized + `google.protobuf.FileDescriptorSet `__. """ statements: MutableSequence[str] = proto.RepeatedField( proto.STRING, number=1, ) + proto_descriptors: bytes = proto.Field( + proto.BYTES, + number=2, + ) class ListDatabaseOperationsRequest(proto.Message): diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index b364514d09..1d8425aa48 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -17,9 +17,12 @@ import datetime import decimal import math +import base64 from google.protobuf.struct_pb2 import ListValue from google.protobuf.struct_pb2 import Value +from google.protobuf.message import Message +from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper from google.api_core import datetime_helpers from google.cloud._helpers import _date_from_iso8601_date @@ -170,6 +173,12 @@ def _make_value_pb(value): return Value(null_value="NULL_VALUE") else: return Value(string_value=value) + if isinstance(value, Message): + value = value.SerializeToString() + if value is None: + return Value(null_value="NULL_VALUE") + else: + return Value(string_value=base64.b64encode(value)) raise ValueError("Unknown type: %s" % (value,)) @@ -198,7 +207,7 @@ def _make_list_value_pbs(values): return [_make_list_value_pb(row) for row in values] -def _parse_value_pb(value_pb, field_type): +def _parse_value_pb(value_pb, field_type, field_name, column_info=None): """Convert a Value protobuf to cell data. 
:type value_pb: :class:`~google.protobuf.struct_pb2.Value` @@ -207,6 +216,12 @@ def _parse_value_pb(value_pb, field_type): :type field_type: :class:`~google.cloud.spanner_v1.types.Type` :param field_type: type code for the value + :type field_name: str + :param field_name: column name + + :type column_info: dict + :param column_info: (Optional) dict of column name and column information + :rtype: varies on field_type :returns: value extracted from value_pb :raises ValueError: if unknown type is passed @@ -234,18 +249,38 @@ def _parse_value_pb(value_pb, field_type): return DatetimeWithNanoseconds.from_rfc3339(value_pb.string_value) elif type_code == TypeCode.ARRAY: return [ - _parse_value_pb(item_pb, field_type.array_element_type) + _parse_value_pb( + item_pb, field_type.array_element_type, field_name, column_info + ) for item_pb in value_pb.list_value.values ] elif type_code == TypeCode.STRUCT: return [ - _parse_value_pb(item_pb, field_type.struct_type.fields[i].type_) + _parse_value_pb( + item_pb, field_type.struct_type.fields[i].type_, field_name, column_info + ) for (i, item_pb) in enumerate(value_pb.list_value.values) ] elif type_code == TypeCode.NUMERIC: return decimal.Decimal(value_pb.string_value) elif type_code == TypeCode.JSON: return JsonObject.from_str(value_pb.string_value) + elif type_code == TypeCode.PROTO: + bytes_value = base64.b64decode(value_pb.string_value) + if column_info is not None and column_info.get(field_name) is not None: + proto_message = column_info.get(field_name) + if isinstance(proto_message, Message): + proto_message = proto_message.__deepcopy__() + proto_message.ParseFromString(bytes_value) + return proto_message + return bytes_value + elif type_code == TypeCode.ENUM: + int_value = int(value_pb.string_value) + if column_info is not None and column_info.get(field_name) is not None: + proto_enum = column_info.get(field_name) + if isinstance(proto_enum, EnumTypeWrapper): + return proto_enum.Name(int_value) + return int_value else: 
raise ValueError("Unknown type: %s" % (field_type,)) @@ -266,7 +301,7 @@ def _parse_list_value_pbs(rows, row_type): for row in rows: row_data = [] for value_pb, field in zip(row.values, row_type.fields): - row_data.append(_parse_value_pb(value_pb, field.type_)) + row_data.append(_parse_value_pb(value_pb, field.type_, field.name)) result.append(row_data) return result diff --git a/google/cloud/spanner_v1/data_types.py b/google/cloud/spanner_v1/data_types.py index fca0fcf982..130603afa9 100644 --- a/google/cloud/spanner_v1/data_types.py +++ b/google/cloud/spanner_v1/data_types.py @@ -15,6 +15,10 @@ """Custom data types for spanner.""" import json +import types + +from google.protobuf.message import Message +from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper class JsonObject(dict): @@ -71,3 +75,109 @@ def serialize(self): return json.dumps(self._array_value, sort_keys=True, separators=(",", ":")) return json.dumps(self, sort_keys=True, separators=(",", ":")) + + +def _proto_message(bytes_val, proto_message_object): + """Helper for :func:`get_proto_message`. + parses serialized protocol buffer bytes data into proto message. + + Args: + bytes_val (bytes): bytes object. + proto_message_object (Message): Message object for parsing + + Returns: + Message: parses serialized protocol buffer data into this message. + + Raises: + ValueError: if the input proto_message_object is not of type Message + """ + if isinstance(bytes_val, types.NoneType): + return None + + if not isinstance(bytes_val, bytes): + raise ValueError("Expected input bytes_val to be a string") + + proto_message = proto_message_object.__deepcopy__() + proto_message.ParseFromString(bytes_val) + return proto_message + + +def _proto_enum(int_val, proto_enum_object): + """Helper for :func:`get_proto_enum`. + parses int value into string containing the name of an enum value. + + Args: + int_val (int): integer value. + proto_enum_object (EnumTypeWrapper): Enum object. 
+ + Returns: + str: string containing the name of an enum value. + + Raises: + ValueError: if the input proto_enum_object is not of type EnumTypeWrapper + """ + if isinstance(int_val, types.NoneType): + return None + + if not isinstance(int_val, int): + raise ValueError("Expected input int_val to be a integer") + + return proto_enum_object.Name(int_val) + + +def get_proto_message(bytes_string, proto_message_object): + """parses serialized protocol buffer bytes' data or its list into proto message or list of proto message. + + Args: + bytes_string (bytes or list[bytes]): bytes object. + proto_message_object (Message): Message object for parsing + + Returns: + Message or list[Message]: parses serialized protocol buffer data into this message. + + Raises: + ValueError: if the input proto_message_object is not of type Message + """ + if isinstance(bytes_string, types.NoneType): + return None + + if not isinstance(proto_message_object, Message): + raise ValueError("Input proto_message_object should be of type Message") + + if not isinstance(bytes_string, (bytes, list)): + raise ValueError( + "Expected input bytes_string to be a string or list of strings" + ) + + if isinstance(bytes_string, list): + return [_proto_message(item, proto_message_object) for item in bytes_string] + + return _proto_message(bytes_string, proto_message_object) + + +def get_proto_enum(int_value, proto_enum_object): + """parses int or list of int values into enum or list of enum values. + + Args: + int_value (int or list[int]): list of integer value. + proto_enum_object (EnumTypeWrapper): Enum object. + + Returns: + str or list[str]: list of strings containing the name of enum value. 
+ + Raises: + ValueError: if the input int_list is not of type list + """ + if isinstance(int_value, types.NoneType): + return None + + if not isinstance(proto_enum_object, EnumTypeWrapper): + raise ValueError("Input proto_enum_object should be of type EnumTypeWrapper") + + if not isinstance(int_value, (int, list)): + raise ValueError("Expected input int_value to be a integer or list of integers") + + if isinstance(int_value, list): + return [_proto_enum(item, proto_enum_object) for item in int_value] + + return _proto_enum(int_value, proto_enum_object) diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index f919fa2c5e..e2f5a79810 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -124,6 +124,9 @@ class Database(object): (Optional) database dialect for the database :type database_role: str or None :param database_role: (Optional) user-assigned database_role for the session. + :type proto_descriptors: bytes + :param proto_descriptors: (Optional) Proto descriptors used by CREATE/ALTER PROTO BUNDLE + statements in 'ddl_statements' above. """ _spanner_api = None @@ -138,6 +141,7 @@ def __init__( encryption_config=None, database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED, database_role=None, + proto_descriptors=None, ): self.database_id = database_id self._instance = instance @@ -155,6 +159,7 @@ def __init__( self._encryption_config = encryption_config self._database_dialect = database_dialect self._database_role = database_role + self._proto_descriptors = proto_descriptors if pool is None: pool = BurstyPool(database_role=database_role) @@ -328,6 +333,14 @@ def database_role(self): """ return self._database_role + @property + def proto_descriptors(self): + """Proto Descriptors for this database. + :rtype: bytes + :returns: bytes representing the proto descriptors for this database + """ + return self._proto_descriptors + @property def logger(self): """Logger used by the database. 
@@ -411,6 +424,7 @@ def create(self): extra_statements=list(self._ddl_statements), encryption_config=self._encryption_config, database_dialect=self._database_dialect, + proto_descriptors=self._proto_descriptors, ) future = api.create_database(request=request, metadata=metadata) return future @@ -447,6 +461,7 @@ def reload(self): metadata = _metadata_with_prefix(self.name) response = api.get_database_ddl(database=self.name, metadata=metadata) self._ddl_statements = tuple(response.statements) + self._proto_descriptors = response.proto_descriptors response = api.get_database(name=self.name, metadata=metadata) self._state = DatabasePB.State(response.state) self._create_time = response.create_time @@ -458,7 +473,7 @@ def reload(self): self._default_leader = response.default_leader self._database_dialect = response.database_dialect - def update_ddl(self, ddl_statements, operation_id=""): + def update_ddl(self, ddl_statements, operation_id="", proto_descriptors=None): """Update DDL for this database. Apply any configured schema from :attr:`ddl_statements`. 
@@ -470,6 +485,8 @@ def update_ddl(self, ddl_statements, operation_id=""): :param ddl_statements: a list of DDL statements to use on this database :type operation_id: str :param operation_id: (optional) a string ID for the long-running operation + :type proto_descriptors: bytes + :param proto_descriptors: (optional) Proto descriptors used by CREATE/ALTER PROTO BUNDLE statements :rtype: :class:`google.api_core.operation.Operation` :returns: an operation instance @@ -483,6 +500,7 @@ def update_ddl(self, ddl_statements, operation_id=""): database=self.name, statements=ddl_statements, operation_id=operation_id, + proto_descriptors=proto_descriptors, ) future = api.update_database_ddl(request=request, metadata=metadata) diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index f972f817b3..8e3254c311 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -432,6 +432,7 @@ def database( encryption_config=None, database_dialect=DatabaseDialect.DATABASE_DIALECT_UNSPECIFIED, database_role=None, + proto_descriptors=None, ): """Factory to create a database within this instance. @@ -467,6 +468,10 @@ def database( :param database_dialect: (Optional) database dialect for the database + :type proto_descriptors: bytes + :param proto_descriptors: (Optional) Proto descriptors used by CREATE/ALTER PROTO BUNDLE + statements in 'ddl_statements' above. + :rtype: :class:`~google.cloud.spanner_v1.database.Database` :returns: a database owned by this instance. 
""" @@ -479,6 +484,7 @@ def database( encryption_config=encryption_config, database_dialect=database_dialect, database_role=database_role, + proto_descriptors=proto_descriptors, ) def list_databases(self, page_size=None): diff --git a/google/cloud/spanner_v1/param_types.py b/google/cloud/spanner_v1/param_types.py index 0c03f7ecc6..198fe1c8ae 100644 --- a/google/cloud/spanner_v1/param_types.py +++ b/google/cloud/spanner_v1/param_types.py @@ -18,6 +18,8 @@ from google.cloud.spanner_v1 import TypeAnnotationCode from google.cloud.spanner_v1 import TypeCode from google.cloud.spanner_v1 import StructType +from google.protobuf.message import Message +from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper # Scalar parameter types @@ -71,3 +73,35 @@ def Struct(fields): :returns: the appropriate struct-type protobuf """ return Type(code=TypeCode.STRUCT, struct_type=StructType(fields=fields)) + + +def ProtoMessage(proto_message_object): + """Construct a proto message type description protobuf. + + :type proto_message_object: :class:`google.protobuf.message.Message` + :param proto_message_object: the proto message instance + + :rtype: :class:`type_pb2.Type` + :returns: the appropriate proto-message-type protobuf + """ + if not isinstance(proto_message_object, Message): + raise ValueError("Expected input object of type Proto Message.") + return Type( + code=TypeCode.PROTO, proto_type_fqn=proto_message_object.DESCRIPTOR.full_name + ) + + +def ProtoEnum(proto_enum_object): + """Construct a proto enum type description protobuf. 
+ + :type proto_enum_object: :class:`google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper` + :param proto_enum_object: the proto enum instance + + :rtype: :class:`type_pb2.Type` + :returns: the appropriate proto-enum-type protobuf + """ + if not isinstance(proto_enum_object, EnumTypeWrapper): + raise ValueError("Expected input object of type Proto Enum") + return Type( + code=TypeCode.ENUM, proto_type_fqn=proto_enum_object.DESCRIPTOR.full_name + ) diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 5b1ca6fbb8..50881b3f69 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -213,7 +213,7 @@ def snapshot(self, **kw): return Snapshot(self, **kw) - def read(self, table, columns, keyset, index="", limit=0): + def read(self, table, columns, keyset, index="", limit=0, column_info=None): """Perform a ``StreamingRead`` API request for rows in a table. :type table: str @@ -232,10 +232,15 @@ def read(self, table, columns, keyset, index="", limit=0): :type limit: int :param limit: (Optional) maximum number of rows to return + :type column_info: dict + :param column_info: (Optional) dict of mapping between column names and additional column information + :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ - return self.snapshot().read(table, columns, keyset, index, limit) + return self.snapshot().read( + table, columns, keyset, index, limit, column_info=column_info + ) def execute_sql( self, @@ -247,6 +252,7 @@ def execute_sql( request_options=None, retry=method.DEFAULT, timeout=method.DEFAULT, + column_info=None, ): """Perform an ``ExecuteStreamingSql`` API request. @@ -286,6 +292,9 @@ def execute_sql( :type timeout: float :param timeout: (Optional) The timeout for this request. 
+ :type column_info: dict + :param column_info: (Optional) dict of mapping between column names and additional column information + :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. """ @@ -298,6 +307,7 @@ def execute_sql( request_options=request_options, retry=retry, timeout=timeout, + column_info=column_info, ) def batch(self): diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index f1fff8b533..1c50778c4c 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -170,6 +170,7 @@ def read( *, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, + column_info=None, ): """Perform a ``StreamingRead`` API request for rows in a table. @@ -210,6 +211,9 @@ def read( :type timeout: float :param timeout: (Optional) The timeout for this request. + :type column_info: dict + :param column_info: (Optional) dict of mapping between column names and additional column information + :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. 
@@ -271,9 +275,11 @@ def read( ) self._read_request_count += 1 if self._multi_use: - return StreamedResultSet(iterator, source=self) + return StreamedResultSet( + iterator, source=self, column_info=column_info + ) else: - return StreamedResultSet(iterator) + return StreamedResultSet(iterator, column_info=column_info) else: iterator = _restart_on_unavailable( restart, @@ -287,9 +293,9 @@ def read( self._read_request_count += 1 if self._multi_use: - return StreamedResultSet(iterator, source=self) + return StreamedResultSet(iterator, source=self, column_info=column_info) else: - return StreamedResultSet(iterator) + return StreamedResultSet(iterator, column_info=column_info) def execute_sql( self, @@ -302,6 +308,7 @@ def execute_sql( partition=None, retry=gapic_v1.method.DEFAULT, timeout=gapic_v1.method.DEFAULT, + column_info=None, ): """Perform an ``ExecuteStreamingSql`` API request. @@ -351,6 +358,9 @@ def execute_sql( :type timeout: float :param timeout: (Optional) The timeout for this request. + :type column_info: dict + :param column_info: (Optional) dict of mapping between column names and additional column information + :raises ValueError: for reuse of single-use snapshots, or if a transaction ID is already pending for multiple-use snapshots. 
@@ -426,9 +436,11 @@ def execute_sql( self._execute_sql_count += 1 if self._multi_use: - return StreamedResultSet(iterator, source=self) + return StreamedResultSet( + iterator, source=self, column_info=column_info + ) else: - return StreamedResultSet(iterator) + return StreamedResultSet(iterator, column_info=column_info) else: iterator = _restart_on_unavailable( restart, @@ -443,9 +455,9 @@ def execute_sql( self._execute_sql_count += 1 if self._multi_use: - return StreamedResultSet(iterator, source=self) + return StreamedResultSet(iterator, source=self, column_info=column_info) else: - return StreamedResultSet(iterator) + return StreamedResultSet(iterator, column_info=column_info) def partition_read( self, diff --git a/google/cloud/spanner_v1/streamed.py b/google/cloud/spanner_v1/streamed.py index 80a452d558..850ddc1726 100644 --- a/google/cloud/spanner_v1/streamed.py +++ b/google/cloud/spanner_v1/streamed.py @@ -37,7 +37,7 @@ class StreamedResultSet(object): :param source: Snapshot from which the result set was fetched. """ - def __init__(self, response_iterator, source=None): + def __init__(self, response_iterator, source=None, column_info=None): self._response_iterator = response_iterator self._rows = [] # Fully-processed rows self._metadata = None # Until set from first PRS @@ -45,6 +45,7 @@ def __init__(self, response_iterator, source=None): self._current_row = [] # Accumulated values for incomplete row self._pending_chunk = None # Incomplete value self._source = source # Source snapshot + self._column_info = column_info # Column information @property def fields(self): @@ -99,10 +100,15 @@ def _merge_values(self, values): :param values: non-chunked values from partial result set. 
""" field_types = [field.type_ for field in self.fields] + field_names = [field.name for field in self.fields] width = len(field_types) index = len(self._current_row) for value in values: - self._current_row.append(_parse_value_pb(value, field_types[index])) + self._current_row.append( + _parse_value_pb( + value, field_types[index], field_names[index], self._column_info + ) + ) index += 1 if index == width: self._rows.append(self._current_row) diff --git a/google/cloud/spanner_v1/types/type.py b/google/cloud/spanner_v1/types/type.py index 1c9626002c..0d378a2efa 100644 --- a/google/cloud/spanner_v1/types/type.py +++ b/google/cloud/spanner_v1/types/type.py @@ -105,6 +105,8 @@ class TypeCode(proto.Enum): STRUCT = 9 NUMERIC = 10 JSON = 11 + PROTO = 13 + ENUM = 14 class TypeAnnotationCode(proto.Enum): @@ -170,6 +172,11 @@ class Type(proto.Message): typically is not needed to process the content of a value (it doesn't affect serialization) and clients can ignore it on the read path. + proto_type_fqn (str): + If [code][] == [PROTO][TypeCode.PROTO] or [code][] == + [ENUM][TypeCode.ENUM], then ``proto_type_fqn`` is the fully + qualified name of the proto type representing the proto/enum + definition. 
""" code: "TypeCode" = proto.Field( @@ -192,6 +199,10 @@ class Type(proto.Message): number=4, enum="TypeAnnotationCode", ) + proto_type_fqn: str = proto.Field( + proto.STRING, + number=5, + ) class StructType(proto.Message): diff --git a/samples/samples/conftest.py b/samples/samples/conftest.py index c63548c460..6747199022 100644 --- a/samples/samples/conftest.py +++ b/samples/samples/conftest.py @@ -114,6 +114,11 @@ def multi_region_instance_config(spanner_client): return "{}/instanceConfigs/{}".format(spanner_client.project_name, "nam3") +@pytest.fixture(scope="module") +def proto_descriptor_file(): + return open("../../samples/samples/testdata/descriptors.pb", 'rb').read() + + @pytest.fixture(scope="module") def sample_instance( spanner_client, @@ -208,7 +213,8 @@ def sample_database( sample_instance, database_id, database_ddl, - database_dialect): + database_dialect, + proto_descriptor_file): if database_dialect == DatabaseDialect.POSTGRESQL: sample_database = sample_instance.database( database_id, @@ -236,6 +242,7 @@ def sample_database( sample_database = sample_instance.database( database_id, ddl_statements=database_ddl, + proto_descriptors=proto_descriptor_file ) if not sample_database.exists(): diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index a447121010..0542031f96 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -31,10 +31,12 @@ from google.cloud import spanner from google.cloud.spanner_admin_instance_v1.types import spanner_instance_admin from google.cloud.spanner_v1 import param_types -from google.type import expr_pb2 +from google.cloud.spanner_v1.data_types import JsonObject, get_proto_message, get_proto_enum from google.iam.v1 import policy_pb2 -from google.cloud.spanner_v1.data_types import JsonObject from google.protobuf import field_mask_pb2 # type: ignore +from google.type import expr_pb2 +from samples.samples.testdata import singer_pb2 + OPERATION_TIMEOUT_SECONDS = 240 @@ -278,6 +280,46 @@ 
def create_database_with_default_leader(instance_id, database_id, default_leader # [END spanner_create_database_with_default_leader] +# [START spanner_create_database_with_proto_descriptors] +def create_database_with_proto_descriptors(instance_id, database_id): + """Creates a database with proto descriptors and tables with proto columns for sample data.""" + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + + # reads proto descriptor file as bytes + proto_descriptor_file = open("testdata/descriptors.pb", 'rb').read() + + database = instance.database( + database_id, + ddl_statements=[ + """CREATE PROTO BUNDLE ( + spanner.examples.music.SingerInfo, + spanner.examples.music.Genre, + )""", + """CREATE TABLE SingersProto ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + SingerInfo spanner.examples.music.SingerInfo, + SingerGenre spanner.examples.music.Genre, + SingerInfoArray ARRAY, + SingerGenreArray ARRAY, + ) PRIMARY KEY (SingerId)""", + ], + proto_descriptors=proto_descriptor_file + ) + + operation = database.create() + + print("Waiting for operation to complete...") + operation.result(OPERATION_TIMEOUT_SECONDS) + + print("Created database {} with proto descriptors on instance {}".format(database_id, instance_id)) + + +# [END spanner_create_database_with_proto_descriptors] + + # [START spanner_update_database_with_default_leader] def update_database_with_default_leader(instance_id, database_id, default_leader): """Updates a database with tables with a default leader.""" @@ -306,6 +348,46 @@ def update_database_with_default_leader(instance_id, database_id, default_leader # [END spanner_update_database_with_default_leader] +# [START spanner_update_database_with_proto_descriptors] +def update_database_with_proto_descriptors(instance_id, database_id): + """Updates a database with tables with a default leader.""" + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + + 
database = instance.database(database_id) + proto_descriptor_file = open("testdata/descriptors.pb", 'rb').read() + + operation = database.update_ddl( + [ + """CREATE PROTO BUNDLE ( + spanner.examples.music.SingerInfo, + spanner.examples.music.Genre, + )""", + """CREATE TABLE SingersProto ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + SingerInfo spanner.examples.music.SingerInfo, + SingerGenre spanner.examples.music.Genre, + ) PRIMARY KEY (SingerId)""", + ], + proto_descriptors=proto_descriptor_file + ) + print("Waiting for operation to complete...") + operation.result(OPERATION_TIMEOUT_SECONDS) + + database.reload() + + print( + "Database {} updated with proto descriptors".format( + database.name + ) + ) + + +# [END spanner_update_database_with_proto_descriptors] + + # [START spanner_get_database_ddl] def get_database_ddl(instance_id, database_id): """Gets the database DDL statements.""" @@ -316,6 +398,7 @@ def get_database_ddl(instance_id, database_id): print("Retrieved database DDL for {}".format(database_id)) for statement in ddl.statements: print(statement) + print(ddl.proto_descriptors) # [END spanner_get_database_ddl] @@ -2428,6 +2511,165 @@ def enable_fine_grained_access( # [END spanner_enable_fine_grained_access] +# [START spanner_insert_proto_columns_data_with_dml] +def insert_proto_columns_data_with_dml(instance_id, database_id): + """Inserts sample proto column data into the given database using a DML statement.""" + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + singer_info = singer_pb2.SingerInfo() + singer_info.singer_id = 1 + singer_info.birth_date = "January" + singer_info.nationality = "Country1" + singer_info.genre = singer_pb2.Genre.ROCK + + singer_info_array = [singer_info, None] + singer_genre_array = [singer_pb2.Genre.ROCK, None] + + def insert_singers_with_proto_column(transaction): + row_ct = transaction.execute_update( + 
"INSERT INTO SingersProto (SingerId, FirstName, LastName, SingerInfo, SingerGenre, SingerInfoArray," + " SingerGenreArray) " + " VALUES (1, 'Virginia', 'Watson', @singerInfo, @singerGenre, @singerInfoArray, @singerGenreArray)", + params={ + "singerInfo": singer_info, + "singerGenre": singer_pb2.Genre.ROCK, + "singerInfoArray": singer_info_array, + "singerGenreArray": singer_genre_array + }, + param_types={ + "singerInfo": param_types.ProtoMessage(singer_info), + "singerGenre": param_types.ProtoEnum(singer_pb2.Genre), + "singerInfoArray": param_types.Array(param_types.ProtoMessage(singer_info)), + "singerGenreArray": param_types.Array(param_types.ProtoEnum(singer_pb2.Genre)) + } + ) + + print("{} record(s) inserted.".format(row_ct)) + + database.run_in_transaction(insert_singers_with_proto_column) + + +# [END spanner_insert_proto_columns_data_with_dml] + + +# [START spanner_insert_proto_columns_data] +def insert_proto_columns_data(instance_id, database_id): + """Inserts sample proto column data into the given database. + + The database and table must already exist and can be created using + `create_database`. 
+ """ + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + singer_info = singer_pb2.SingerInfo() + singer_info.singer_id = 2 + singer_info.birth_date = "February" + singer_info.nationality = "Country2" + singer_info.genre = singer_pb2.Genre.FOLK + + singer_info_array = [singer_info] + singer_genre_array = [singer_pb2.Genre.FOLK] + + with database.batch() as batch: + batch.insert( + table="SingersProto", + columns=("SingerId", "FirstName", "LastName", "SingerInfo", "SingerGenre", "SingerInfoArray", + "SingerGenreArray"), + values=[ + (2, "Marc", "Richards", singer_info, singer_pb2.Genre.ROCK, singer_info_array, singer_genre_array), + (3, "Catalina", "Smith", None, None, None, None), + ], + ) + + print("Inserted data.") + + +# [END spanner_insert_proto_columns_data] + + +# [START spanner_read_proto_columns_data] +def read_proto_columns_data(instance_id, database_id): + """Reads sample proto column data from the database.""" + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + with database.snapshot() as snapshot: + keyset = spanner.KeySet(all_=True) + results = snapshot.read( + table="SingersProto", + columns=("SingerId", "FirstName", "LastName", "SingerInfo", "SingerGenre", "SingerInfoArray", "SingerGenreArray"), + keyset=keyset, + column_info={"SingerInfo": singer_pb2.SingerInfo(), + "SingerGenre": singer_pb2.Genre, + "SingerInfoArray": singer_pb2.SingerInfo(), + "SingerGenreArray": singer_pb2.Genre}, + ) + + for row in results: + print("SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, , SingerInfoArray: {}, " + "SingerGenreArray: {}".format(*row)) + + +# [END spanner_read_proto_columns_data] + + +# [START spanner_read_proto_columns_data_using_helper_method] +def read_proto_columns_data_using_helper_method(instance_id, database_id): + """Reads sample proto column data from the 
database.""" + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + with database.snapshot() as snapshot: + keyset = spanner.KeySet(all_=True) + results = snapshot.read( + table="SingersProto", + columns=("SingerId", "FirstName", "LastName", "SingerInfo", "SingerGenre", "SingerInfoArray", "SingerGenreArray"), + keyset=keyset, + ) + + for row in results: + singer_info_proto_msg = get_proto_message(row[3], singer_pb2.SingerInfo()) + singer_genre_proto_enum = get_proto_enum(row[4], singer_pb2.Genre) + singer_info_list = get_proto_message(row[5], singer_pb2.SingerInfo()) + singer_genre_list = get_proto_enum(row[6], singer_pb2.Genre) + print("SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, SingerInfoArray: {}, " + "SingerGenreArray: {}".format(row[0], row[1], row[2], singer_info_proto_msg, singer_genre_proto_enum, + singer_info_list, singer_genre_list)) + + +# [END spanner_read_proto_columns_data_using_helper_method] + + +# [START spanner_query_proto_columns_data] +def query_proto_columns_data(instance_id, database_id): + """Queries sample proto column data from the database using SQL.""" + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + with database.snapshot() as snapshot: + results = snapshot.execute_sql( + "SELECT SingerId, FirstName, LastName, SingerInfo, SingerGenre, SingerInfoArray, SingerGenreArray FROM SingersProto", + column_info={"SingerInfo": singer_pb2.SingerInfo(), + "SingerGenre": singer_pb2.Genre, + "SingerInfoArray": singer_pb2.SingerInfo(), + "SingerGenreArray": singer_pb2.Genre}, + ) + + for row in results: + print("SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, , SingerInfoArray: {}, " + "SingerGenreArray: {}".format(*row)) + + +# [END spanner_query_proto_columns_data] + + if __name__ == "__main__": # noqa: C901 parser = 
argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter @@ -2440,6 +2682,8 @@ def enable_fine_grained_access( subparsers = parser.add_subparsers(dest="command") subparsers.add_parser("create_instance", help=create_instance.__doc__) subparsers.add_parser("create_database", help=create_database.__doc__) + subparsers.add_parser("create_database_with_proto_descriptors", help=create_database_with_proto_descriptors.__doc__) + subparsers.add_parser("get_database_ddl", help=get_database_ddl.__doc__) subparsers.add_parser("insert_data", help=insert_data.__doc__) subparsers.add_parser("delete_data", help=delete_data.__doc__) subparsers.add_parser("query_data", help=query_data.__doc__) @@ -2544,6 +2788,13 @@ def enable_fine_grained_access( "read_data_with_database_role", help=read_data_with_database_role.__doc__ ) subparsers.add_parser("list_database_roles", help=list_database_roles.__doc__) + subparsers.add_parser("insert_proto_columns_data_with_dml", help=insert_proto_columns_data_with_dml.__doc__) + subparsers.add_parser("insert_proto_columns_data", help=insert_proto_columns_data.__doc__) + subparsers.add_parser("read_proto_columns_data", help=read_proto_columns_data.__doc__) + subparsers.add_parser( + "read_proto_columns_data_using_helper_method", help=read_proto_columns_data_using_helper_method.__doc__ + ) + subparsers.add_parser("query_proto_columns_data", help=query_proto_columns_data.__doc__) enable_fine_grained_access_parser = subparsers.add_parser( "enable_fine_grained_access", help=enable_fine_grained_access.__doc__ ) @@ -2561,6 +2812,10 @@ def enable_fine_grained_access( create_instance(args.instance_id) elif args.command == "create_database": create_database(args.instance_id, args.database_id) + elif args.command == "create_database_with_proto_descriptors": + create_database_with_proto_descriptors(args.instance_id, args.database_id) + elif args.command == "get_database_ddl": + get_database_ddl(args.instance_id, 
args.database_id) elif args.command == "insert_data": insert_data(args.instance_id, args.database_id) elif args.command == "delete_data": @@ -2683,3 +2938,13 @@ def enable_fine_grained_access( args.database_role, args.title, ) + elif args.command == "insert_proto_columns_data_with_dml": + insert_proto_columns_data_with_dml(args.instance_id, args.database_id) + elif args.command == "insert_proto_columns_data": + insert_proto_columns_data(args.instance_id, args.database_id) + elif args.command == "read_proto_columns_data": + read_proto_columns_data(args.instance_id, args.database_id) + elif args.command == "read_proto_columns_data_using_helper_method": + read_proto_columns_data_using_helper_method(args.instance_id, args.database_id) + elif args.command == "query_proto_columns_data": + query_proto_columns_data(args.instance_id, args.database_id) diff --git a/samples/samples/snippets_test.py b/samples/samples/snippets_test.py index 6d5822e37b..af80e0b535 100644 --- a/samples/samples/snippets_test.py +++ b/samples/samples/snippets_test.py @@ -44,6 +44,25 @@ INTERLEAVE IN PARENT Singers ON DELETE CASCADE """ +CREATE_TABLE_SINGERS_PROTO = """\ +CREATE TABLE SingersProto ( +SingerId INT64 NOT NULL, +FirstName STRING(1024), +LastName STRING(1024), +SingerInfo spanner.examples.music.SingerInfo, +SingerGenre spanner.examples.music.Genre, +SingerInfoArray ARRAY, +SingerGenreArray ARRAY, +) PRIMARY KEY (SingerId) +""" + +CREATE_PROTO_BUNDLE = """\ +CREATE PROTO BUNDLE ( + spanner.examples.music.SingerInfo, + spanner.examples.music.Genre, + ) +""" + retry_429 = RetryErrors(exceptions.ResourceExhausted, delay=15) @@ -100,7 +119,7 @@ def database_ddl(): Sample testcase modules can override as needed. 
""" - return [CREATE_TABLE_SINGERS, CREATE_TABLE_ALBUMS] + return [CREATE_TABLE_SINGERS, CREATE_TABLE_ALBUMS, CREATE_PROTO_BUNDLE, CREATE_TABLE_SINGERS_PROTO] @pytest.fixture(scope="module") @@ -165,6 +184,13 @@ def test_create_database_with_encryption_config( assert kms_key_name in out +def test_create_database_with_proto_descriptors(capsys, instance_id, database_id): + snippets.create_database_with_proto_descriptors(instance_id, database_id) + out, _ = capsys.readouterr() + assert database_id in out + assert instance_id in out + + def test_get_instance_config(capsys): instance_config = "nam6" snippets.get_instance_config(instance_config) @@ -781,3 +807,47 @@ def test_list_database_roles(capsys, instance_id, sample_database): snippets.list_database_roles(instance_id, sample_database.database_id) out, _ = capsys.readouterr() assert "new_parent" in out + + +@pytest.mark.dependency(name="insert_proto_columns_data_dml") +def test_insert_proto_columns_data_with_dml(capsys, instance_id, sample_database): + snippets.insert_proto_columns_data_with_dml(instance_id, sample_database.database_id) + out, _ = capsys.readouterr() + assert "record(s) inserted" in out + + +@pytest.mark.dependency(name="insert_proto_columns_data") +def test_insert_proto_columns_data(capsys, instance_id, sample_database): + snippets.insert_proto_columns_data(instance_id, sample_database.database_id) + out, _ = capsys.readouterr() + assert "Inserted data" in out + + +@pytest.mark.dependency(depends=["insert_proto_columns_data_dml, insert_proto_columns_data"]) +def test_query_proto_columns_data(capsys, instance_id, sample_database): + snippets.query_proto_columns_data(instance_id, sample_database.database_id) + out, _ = capsys.readouterr() + + assert "SingerId: 1, FirstName: Virginia, LastName: Watson" in out + assert "SingerId: 2, FirstName: Marc, LastName: Richards" in out + assert "SingerId: 3, FirstName: Catalina, LastName: Smith" in out + + 
+@pytest.mark.dependency(depends=["insert_proto_columns_data_dml, insert_proto_columns_data"]) +def test_read_proto_columns_data(capsys, instance_id, sample_database): + snippets.read_proto_columns_data(instance_id, sample_database.database_id) + out, _ = capsys.readouterr() + + assert "SingerId: 1, FirstName: Virginia, LastName: Watson" in out + assert "SingerId: 2, FirstName: Marc, LastName: Richards" in out + assert "SingerId: 3, FirstName: Catalina, LastName: Smith" in out + + +@pytest.mark.dependency(depends=["insert_proto_columns_data_dml, insert_proto_columns_data"]) +def test_read_proto_columns_data_using_helper_method(capsys, instance_id, sample_database): + snippets.read_proto_columns_data_using_helper_method(instance_id, sample_database.database_id) + out, _ = capsys.readouterr() + + assert "SingerId: 1, FirstName: Virginia, LastName: Watson" in out + assert "SingerId: 2, FirstName: Marc, LastName: Richards" in out + assert "SingerId: 3, FirstName: Catalina, LastName: Smith" in out diff --git a/samples/samples/testdata/descriptors.pb b/samples/samples/testdata/descriptors.pb new file mode 100644 index 0000000000000000000000000000000000000000..3ebb79420b3ffd2ca3b3b57433a4a10bfa22b675 GIT binary patch literal 251 zcmd=3!N|o^oSB!NTBKJ{lwXoBBvxFIn3o6SrdA~87UZNB>*bafXC^DnXXN4v1}pT; zOUoCM=Hi5Ci_c7vU{qk#U=HGd2zaIl$#QWeWfqlW#HS>dq)IRWWjTX5!6Gg|0U-r0 z?!3g3%>2B>oXnC+31+Z7vXGE57i)TIUQwz93s8>FNLCNKqx9TCih>|&we+}H!F(Zh lF6IFL009Oe4lWii$EYX)Mi9%*-^W{k3B(HWclH)w1^^+RM@0Yt literal 0 HcmV?d00001 diff --git a/samples/samples/testdata/singer.proto b/samples/samples/testdata/singer.proto new file mode 100644 index 0000000000..8dde1bccae --- /dev/null +++ b/samples/samples/testdata/singer.proto @@ -0,0 +1,17 @@ +syntax = "proto2"; + +package spanner.examples.music; + +message SingerInfo { + optional int64 singer_id = 1; + optional string birth_date = 2; + optional string nationality = 3; + optional Genre genre = 4; +} + +enum Genre { + POP = 0; + JAZZ = 1; + FOLK = 2; + 
ROCK = 3; +} diff --git a/samples/samples/testdata/singer_pb2.py b/samples/samples/testdata/singer_pb2.py new file mode 100644 index 0000000000..cdb44c74af --- /dev/null +++ b/samples/samples/testdata/singer_pb2.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: singer.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0csinger.proto\x12\x16spanner.examples.music\"v\n\nSingerInfo\x12\x11\n\tsinger_id\x18\x01 \x01(\x03\x12\x12\n\nbirth_date\x18\x02 \x01(\t\x12\x13\n\x0bnationality\x18\x03 \x01(\t\x12,\n\x05genre\x18\x04 \x01(\x0e\x32\x1d.spanner.examples.music.Genre*.\n\x05Genre\x12\x07\n\x03POP\x10\x00\x12\x08\n\x04JAZZ\x10\x01\x12\x08\n\x04\x46OLK\x10\x02\x12\x08\n\x04ROCK\x10\x03') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'singer_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _GENRE._serialized_start=160 + 
_GENRE._serialized_end=206 + _SINGERINFO._serialized_start=40 + _SINGERINFO._serialized_end=158 +# @@protoc_insertion_point(module_scope) diff --git a/scripts/fixup_spanner_admin_database_v1_keywords.py b/scripts/fixup_spanner_admin_database_v1_keywords.py index ad31a48c81..6049c63dd2 100644 --- a/scripts/fixup_spanner_admin_database_v1_keywords.py +++ b/scripts/fixup_spanner_admin_database_v1_keywords.py @@ -41,7 +41,7 @@ class spanner_admin_databaseCallTransformer(cst.CSTTransformer): METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { 'copy_backup': ('parent', 'backup_id', 'source_backup', 'expire_time', 'encryption_config', ), 'create_backup': ('parent', 'backup_id', 'backup', 'encryption_config', ), - 'create_database': ('parent', 'create_statement', 'extra_statements', 'encryption_config', 'database_dialect', ), + 'create_database': ('parent', 'create_statement', 'extra_statements', 'encryption_config', 'database_dialect', 'proto_descriptors', ), 'delete_backup': ('name', ), 'drop_database': ('database', ), 'get_backup': ('name', ), @@ -57,7 +57,7 @@ class spanner_admin_databaseCallTransformer(cst.CSTTransformer): 'set_iam_policy': ('resource', 'policy', 'update_mask', ), 'test_iam_permissions': ('resource', 'permissions', ), 'update_backup': ('backup', 'update_mask', ), - 'update_database_ddl': ('database', 'statements', 'operation_id', ), + 'update_database_ddl': ('database', 'statements', 'operation_id', 'proto_descriptors', ), } def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: diff --git a/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py b/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py index b9041dd1d2..21cef8415f 100644 --- a/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py +++ b/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py @@ -2106,6 +2106,7 @@ def test_get_database_ddl(request_type, transport: str = "grpc"): # Designate an appropriate return value for the call. 
call.return_value = spanner_database_admin.GetDatabaseDdlResponse( statements=["statements_value"], + proto_descriptors=b"proto_descriptors_blob", ) response = client.get_database_ddl(request) @@ -2117,6 +2118,7 @@ def test_get_database_ddl(request_type, transport: str = "grpc"): # Establish that the response is the type that we expect. assert isinstance(response, spanner_database_admin.GetDatabaseDdlResponse) assert response.statements == ["statements_value"] + assert response.proto_descriptors == b"proto_descriptors_blob" def test_get_database_ddl_empty_call(): @@ -2155,6 +2157,7 @@ async def test_get_database_ddl_async( call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( spanner_database_admin.GetDatabaseDdlResponse( statements=["statements_value"], + proto_descriptors=b"proto_descriptors_blob", ) ) response = await client.get_database_ddl(request) @@ -2167,6 +2170,7 @@ async def test_get_database_ddl_async( # Establish that the response is the type that we expect. assert isinstance(response, spanner_database_admin.GetDatabaseDdlResponse) assert response.statements == ["statements_value"] + assert response.proto_descriptors == b"proto_descriptors_blob" @pytest.mark.asyncio From ee0bf5bebd7359c483862284b683fc1eadb815ad Mon Sep 17 00:00:00 2001 From: Sri Harsha CH <57220027+harshachinta@users.noreply.github.com> Date: Mon, 3 Apr 2023 09:59:32 +0530 Subject: [PATCH 02/20] feat: Proto column feature tests and samples (#921) * feat: add integration tests for Proto Columns * feat: add unit tests for Proto Columns * feat: update tests to add column_info argument at end * feat: remove deepcopy during deserialization of proto message * feat: tests refactoring * feat: integration tests refactoring * feat: samples and sample tests refactoring * feat: lint tests folder * feat:lint samples directory * feat: stop running emulator with proto ddl commands * feat: close the file after reading * feat: update protobuf version lower bound to >3.20 to check proto message 
compatibility * feat: update setup for snippets_tests.py file * feat: add integration tests * feat: remove duplicate integration tests * feat: add proto_descriptor parameter to required tests * feat: add compatibility tests between Proto message, Bytes and Proto Enum, Int64 * feat: add index tests for proto columns * feat: replace duplicates with sample data * feat: update protobuf lower bound version in setup.py file to add support for proto messages and enum * feat: lint fixes * feat: lint fix * feat: tests refactoring * feat: change comment from dml to dql for read * feat: tests refactoring for update db operation --- google/cloud/spanner_v1/_helpers.py | 6 +- samples/samples/conftest.py | 37 +++- samples/samples/snippets.py | 311 +++++++++++++++++---------- samples/samples/snippets_test.py | 77 +++++-- setup.py | 2 +- testing/constraints-3.7.txt | 2 +- tests/_fixtures.py | 22 ++ tests/system/_helpers.py | 2 + tests/system/_sample_data.py | 27 ++- tests/system/conftest.py | 16 +- tests/system/test_backup_api.py | 5 +- tests/system/test_database_api.py | 79 +++++-- tests/system/test_session_api.py | 182 ++++++++++++++-- tests/system/testdata/descriptors.pb | Bin 0 -> 251 bytes tests/unit/test__helpers.py | 104 +++++++-- tests/unit/test_database.py | 80 +++++++ tests/unit/test_instance.py | 3 + tests/unit/test_param_types.py | 34 +++ tests/unit/test_session.py | 10 +- 19 files changed, 807 insertions(+), 192 deletions(-) create mode 100644 tests/system/testdata/descriptors.pb diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 1d8425aa48..bf1094180e 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -268,9 +268,9 @@ def _parse_value_pb(value_pb, field_type, field_name, column_info=None): elif type_code == TypeCode.PROTO: bytes_value = base64.b64decode(value_pb.string_value) if column_info is not None and column_info.get(field_name) is not None: - proto_message = 
column_info.get(field_name) - if isinstance(proto_message, Message): - proto_message = proto_message.__deepcopy__() + default_proto_message = column_info.get(field_name) + if isinstance(default_proto_message, Message): + proto_message = type(default_proto_message)() proto_message.ParseFromString(bytes_value) return proto_message return bytes_value diff --git a/samples/samples/conftest.py b/samples/samples/conftest.py index 6747199022..674d61099d 100644 --- a/samples/samples/conftest.py +++ b/samples/samples/conftest.py @@ -116,7 +116,13 @@ def multi_region_instance_config(spanner_client): @pytest.fixture(scope="module") def proto_descriptor_file(): - return open("../../samples/samples/testdata/descriptors.pb", 'rb').read() + import os + + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "testdata/descriptors.pb") + file = open(filename, "rb") + yield file.read() + file.close() @pytest.fixture(scope="module") @@ -213,8 +219,7 @@ def sample_database( sample_instance, database_id, database_ddl, - database_dialect, - proto_descriptor_file): + database_dialect): if database_dialect == DatabaseDialect.POSTGRESQL: sample_database = sample_instance.database( database_id, @@ -242,7 +247,6 @@ def sample_database( sample_database = sample_instance.database( database_id, ddl_statements=database_ddl, - proto_descriptors=proto_descriptor_file ) if not sample_database.exists(): @@ -254,6 +258,31 @@ def sample_database( sample_database.drop() +@pytest.fixture(scope="module") +def sample_database_for_proto_columns( + spanner_client, + sample_instance, + database_id, + database_ddl_for_proto_columns, + database_dialect, + proto_descriptor_file, +): + if database_dialect == DatabaseDialect.GOOGLE_STANDARD_SQL: + sample_database = sample_instance.database( + database_id, + ddl_statements=database_ddl_for_proto_columns, + proto_descriptors=proto_descriptor_file, + ) + + if not sample_database.exists(): + operation = sample_database.create() + 
operation.result(OPERATION_TIMEOUT_SECONDS) + + yield sample_database + + sample_database.drop() + + @pytest.fixture(scope="module") def kms_key_name(spanner_client): return "projects/{}/locations/{}/keyRings/{}/cryptoKeys/{}".format( diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index 0542031f96..855a78949b 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -31,7 +31,11 @@ from google.cloud import spanner from google.cloud.spanner_admin_instance_v1.types import spanner_instance_admin from google.cloud.spanner_v1 import param_types -from google.cloud.spanner_v1.data_types import JsonObject, get_proto_message, get_proto_enum +from google.cloud.spanner_v1.data_types import ( + JsonObject, + get_proto_message, + get_proto_enum, +) from google.iam.v1 import policy_pb2 from google.protobuf import field_mask_pb2 # type: ignore from google.type import expr_pb2 @@ -280,14 +284,20 @@ def create_database_with_default_leader(instance_id, database_id, default_leader # [END spanner_create_database_with_default_leader] -# [START spanner_create_database_with_proto_descriptors] -def create_database_with_proto_descriptors(instance_id, database_id): +# [START spanner_create_database_with_proto_descriptor] +def create_database_with_proto_descriptor(instance_id, database_id): """Creates a database with proto descriptors and tables with proto columns for sample data.""" + import os + + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "testdata/descriptors.pb") + spanner_client = spanner.Client() instance = spanner_client.instance(instance_id) # reads proto descriptor file as bytes - proto_descriptor_file = open("testdata/descriptors.pb", 'rb').read() + proto_descriptor_file = open(filename, "rb") + proto_descriptor = proto_descriptor_file.read() database = instance.database( database_id, @@ -296,7 +306,7 @@ def create_database_with_proto_descriptors(instance_id, database_id): spanner.examples.music.SingerInfo, 
spanner.examples.music.Genre, )""", - """CREATE TABLE SingersProto ( + """CREATE TABLE Singers ( SingerId INT64 NOT NULL, FirstName STRING(1024), LastName STRING(1024), @@ -306,18 +316,23 @@ def create_database_with_proto_descriptors(instance_id, database_id): SingerGenreArray ARRAY, ) PRIMARY KEY (SingerId)""", ], - proto_descriptors=proto_descriptor_file + proto_descriptors=proto_descriptor, ) operation = database.create() print("Waiting for operation to complete...") operation.result(OPERATION_TIMEOUT_SECONDS) + proto_descriptor_file.close() - print("Created database {} with proto descriptors on instance {}".format(database_id, instance_id)) + print( + "Created database {} with proto descriptors on instance {}".format( + database_id, instance_id + ) + ) -# [END spanner_create_database_with_proto_descriptors] +# [END spanner_create_database_with_proto_descriptor] # [START spanner_update_database_with_default_leader] @@ -348,14 +363,20 @@ def update_database_with_default_leader(instance_id, database_id, default_leader # [END spanner_update_database_with_default_leader] -# [START spanner_update_database_with_proto_descriptors] -def update_database_with_proto_descriptors(instance_id, database_id): +# [START spanner_update_database_with_proto_descriptor] +def update_database_with_proto_descriptor(instance_id, database_id): """Updates a database with tables with a default leader.""" + import os + + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "testdata/descriptors.pb") + spanner_client = spanner.Client() instance = spanner_client.instance(instance_id) database = instance.database(database_id) - proto_descriptor_file = open("testdata/descriptors.pb", 'rb').read() + proto_descriptor_file = open(filename, "rb") + proto_descriptor = proto_descriptor_file.read() operation = database.update_ddl( [ @@ -363,29 +384,28 @@ def update_database_with_proto_descriptors(instance_id, database_id): spanner.examples.music.SingerInfo, 
spanner.examples.music.Genre, )""", - """CREATE TABLE SingersProto ( + """CREATE TABLE Singers ( SingerId INT64 NOT NULL, FirstName STRING(1024), LastName STRING(1024), SingerInfo spanner.examples.music.SingerInfo, SingerGenre spanner.examples.music.Genre, + SingerInfoArray ARRAY, + SingerGenreArray ARRAY, ) PRIMARY KEY (SingerId)""", ], - proto_descriptors=proto_descriptor_file + proto_descriptors=proto_descriptor, ) print("Waiting for operation to complete...") operation.result(OPERATION_TIMEOUT_SECONDS) + proto_descriptor_file.close() database.reload() - print( - "Database {} updated with proto descriptors".format( - database.name - ) - ) + print("Database {} updated with proto descriptors".format(database.name)) -# [END spanner_update_database_with_proto_descriptors] +# [END spanner_update_database_with_proto_descriptor] # [START spanner_get_database_ddl] @@ -398,7 +418,6 @@ def get_database_ddl(instance_id, database_id): print("Retrieved database DDL for {}".format(database_id)) for statement in ddl.statements: print(statement) - print(ddl.proto_descriptors) # [END spanner_get_database_ddl] @@ -2511,8 +2530,60 @@ def enable_fine_grained_access( # [END spanner_enable_fine_grained_access] -# [START spanner_insert_proto_columns_data_with_dml] -def insert_proto_columns_data_with_dml(instance_id, database_id): +# [START spanner_insert_proto_columns_data] +def insert_proto_columns_data(instance_id, database_id): + """Inserts sample proto column data into the given database. + + The database and table must already exist and can be created using + `create_database`. 
+ """ + spanner_client = spanner.Client(client_options={'api_endpoint':'staging-wrenchworks.sandbox.googleapis.com'}) + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + singer_info = singer_pb2.SingerInfo() + singer_info.singer_id = 2 + singer_info.birth_date = "February" + singer_info.nationality = "Country2" + singer_info.genre = singer_pb2.Genre.FOLK + + singer_info_array = [singer_info] + singer_genre_array = [singer_pb2.Genre.FOLK] + + with database.batch() as batch: + batch.insert( + table="Singers", + columns=( + "SingerId", + "FirstName", + "LastName", + "SingerInfo", + "SingerGenre", + "SingerInfoArray", + "SingerGenreArray", + ), + values=[ + ( + 2, + "Marc", + "Richards", + singer_info, + singer_pb2.Genre.ROCK, + singer_info_array, + singer_genre_array, + ), + (3, "Catalina", "Smith", None, None, None, None), + ], + ) + + print("Inserted data.") + + +# [END spanner_insert_proto_columns_data] + + +# [START spanner_insert_proto_columns_data_using_dml] +def insert_proto_columns_data_using_dml(instance_id, database_id): """Inserts sample proto column data into the given database using a DML statement.""" spanner_client = spanner.Client() instance = spanner_client.instance(instance_id) @@ -2529,21 +2600,25 @@ def insert_proto_columns_data_with_dml(instance_id, database_id): def insert_singers_with_proto_column(transaction): row_ct = transaction.execute_update( - "INSERT INTO SingersProto (SingerId, FirstName, LastName, SingerInfo, SingerGenre, SingerInfoArray," + "INSERT INTO Singers (SingerId, FirstName, LastName, SingerInfo, SingerGenre, SingerInfoArray," " SingerGenreArray) " " VALUES (1, 'Virginia', 'Watson', @singerInfo, @singerGenre, @singerInfoArray, @singerGenreArray)", params={ "singerInfo": singer_info, "singerGenre": singer_pb2.Genre.ROCK, "singerInfoArray": singer_info_array, - "singerGenreArray": singer_genre_array + "singerGenreArray": singer_genre_array, }, param_types={ "singerInfo": 
param_types.ProtoMessage(singer_info), "singerGenre": param_types.ProtoEnum(singer_pb2.Genre), - "singerInfoArray": param_types.Array(param_types.ProtoMessage(singer_info)), - "singerGenreArray": param_types.Array(param_types.ProtoEnum(singer_pb2.Genre)) - } + "singerInfoArray": param_types.Array( + param_types.ProtoMessage(singer_info) + ), + "singerGenreArray": param_types.Array( + param_types.ProtoEnum(singer_pb2.Genre) + ), + }, ) print("{} record(s) inserted.".format(row_ct)) @@ -2551,44 +2626,7 @@ def insert_singers_with_proto_column(transaction): database.run_in_transaction(insert_singers_with_proto_column) -# [END spanner_insert_proto_columns_data_with_dml] - - -# [START spanner_insert_proto_columns_data] -def insert_proto_columns_data(instance_id, database_id): - """Inserts sample proto column data into the given database. - - The database and table must already exist and can be created using - `create_database`. - """ - spanner_client = spanner.Client() - instance = spanner_client.instance(instance_id) - database = instance.database(database_id) - - singer_info = singer_pb2.SingerInfo() - singer_info.singer_id = 2 - singer_info.birth_date = "February" - singer_info.nationality = "Country2" - singer_info.genre = singer_pb2.Genre.FOLK - - singer_info_array = [singer_info] - singer_genre_array = [singer_pb2.Genre.FOLK] - - with database.batch() as batch: - batch.insert( - table="SingersProto", - columns=("SingerId", "FirstName", "LastName", "SingerInfo", "SingerGenre", "SingerInfoArray", - "SingerGenreArray"), - values=[ - (2, "Marc", "Richards", singer_info, singer_pb2.Genre.ROCK, singer_info_array, singer_genre_array), - (3, "Catalina", "Smith", None, None, None, None), - ], - ) - - print("Inserted data.") - - -# [END spanner_insert_proto_columns_data] +# [END spanner_insert_proto_columns_data_using_dml] # [START spanner_read_proto_columns_data] @@ -2601,73 +2639,102 @@ def read_proto_columns_data(instance_id, database_id): with database.snapshot() as 
snapshot: keyset = spanner.KeySet(all_=True) results = snapshot.read( - table="SingersProto", - columns=("SingerId", "FirstName", "LastName", "SingerInfo", "SingerGenre", "SingerInfoArray", "SingerGenreArray"), + table="Singers", + columns=( + "SingerId", + "FirstName", + "LastName", + "SingerInfo", + "SingerGenre", + "SingerInfoArray", + "SingerGenreArray", + ), keyset=keyset, - column_info={"SingerInfo": singer_pb2.SingerInfo(), - "SingerGenre": singer_pb2.Genre, - "SingerInfoArray": singer_pb2.SingerInfo(), - "SingerGenreArray": singer_pb2.Genre}, + column_info={ + "SingerInfo": singer_pb2.SingerInfo(), + "SingerGenre": singer_pb2.Genre, + "SingerInfoArray": singer_pb2.SingerInfo(), + "SingerGenreArray": singer_pb2.Genre, + }, ) for row in results: - print("SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, , SingerInfoArray: {}, " - "SingerGenreArray: {}".format(*row)) + print( + "SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, , SingerInfoArray: {}, " + "SingerGenreArray: {}".format(*row) + ) # [END spanner_read_proto_columns_data] -# [START spanner_read_proto_columns_data_using_helper_method] -def read_proto_columns_data_using_helper_method(instance_id, database_id): - """Reads sample proto column data from the database.""" +# [START spanner_read_proto_columns_data_using_dql] +def read_proto_columns_data_using_dql(instance_id, database_id): + """Queries sample proto column data from the database using SQL.""" spanner_client = spanner.Client() instance = spanner_client.instance(instance_id) database = instance.database(database_id) with database.snapshot() as snapshot: - keyset = spanner.KeySet(all_=True) - results = snapshot.read( - table="SingersProto", - columns=("SingerId", "FirstName", "LastName", "SingerInfo", "SingerGenre", "SingerInfoArray", "SingerGenreArray"), - keyset=keyset, + results = snapshot.execute_sql( + "SELECT SingerId, FirstName, LastName, SingerInfo, SingerGenre, SingerInfoArray, 
SingerGenreArray FROM Singers", + column_info={ + "SingerInfo": singer_pb2.SingerInfo(), + "SingerGenre": singer_pb2.Genre, + "SingerInfoArray": singer_pb2.SingerInfo(), + "SingerGenreArray": singer_pb2.Genre, + }, ) for row in results: - singer_info_proto_msg = get_proto_message(row[3], singer_pb2.SingerInfo()) - singer_genre_proto_enum = get_proto_enum(row[4], singer_pb2.Genre) - singer_info_list = get_proto_message(row[5], singer_pb2.SingerInfo()) - singer_genre_list = get_proto_enum(row[6], singer_pb2.Genre) - print("SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, SingerInfoArray: {}, " - "SingerGenreArray: {}".format(row[0], row[1], row[2], singer_info_proto_msg, singer_genre_proto_enum, - singer_info_list, singer_genre_list)) + print( + "SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, , SingerInfoArray: {}, " + "SingerGenreArray: {}".format(*row) + ) -# [END spanner_read_proto_columns_data_using_helper_method] +# [END spanner_read_proto_columns_data_using_dql] -# [START spanner_query_proto_columns_data] -def query_proto_columns_data(instance_id, database_id): - """Queries sample proto column data from the database using SQL.""" +def read_proto_columns_data_using_helper_method(instance_id, database_id): + """Reads sample proto column data from the database.""" spanner_client = spanner.Client() instance = spanner_client.instance(instance_id) database = instance.database(database_id) with database.snapshot() as snapshot: - results = snapshot.execute_sql( - "SELECT SingerId, FirstName, LastName, SingerInfo, SingerGenre, SingerInfoArray, SingerGenreArray FROM SingersProto", - column_info={"SingerInfo": singer_pb2.SingerInfo(), - "SingerGenre": singer_pb2.Genre, - "SingerInfoArray": singer_pb2.SingerInfo(), - "SingerGenreArray": singer_pb2.Genre}, + keyset = spanner.KeySet(all_=True) + results = snapshot.read( + table="Singers", + columns=( + "SingerId", + "FirstName", + "LastName", + "SingerInfo", + 
"SingerGenre", + "SingerInfoArray", + "SingerGenreArray", + ), + keyset=keyset, ) for row in results: - print("SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, , SingerInfoArray: {}, " - "SingerGenreArray: {}".format(*row)) - - -# [END spanner_query_proto_columns_data] + singer_info_proto_msg = get_proto_message(row[3], singer_pb2.SingerInfo()) + singer_genre_proto_enum = get_proto_enum(row[4], singer_pb2.Genre) + singer_info_list = get_proto_message(row[5], singer_pb2.SingerInfo()) + singer_genre_list = get_proto_enum(row[6], singer_pb2.Genre) + print( + "SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, SingerInfoArray: {}, " + "SingerGenreArray: {}".format( + row[0], + row[1], + row[2], + singer_info_proto_msg, + singer_genre_proto_enum, + singer_info_list, + singer_genre_list, + ) + ) if __name__ == "__main__": # noqa: C901 @@ -2682,7 +2749,6 @@ def query_proto_columns_data(instance_id, database_id): subparsers = parser.add_subparsers(dest="command") subparsers.add_parser("create_instance", help=create_instance.__doc__) subparsers.add_parser("create_database", help=create_database.__doc__) - subparsers.add_parser("create_database_with_proto_descriptors", help=create_database_with_proto_descriptors.__doc__) subparsers.add_parser("get_database_ddl", help=get_database_ddl.__doc__) subparsers.add_parser("insert_data", help=insert_data.__doc__) subparsers.add_parser("delete_data", help=delete_data.__doc__) @@ -2788,13 +2854,28 @@ def query_proto_columns_data(instance_id, database_id): "read_data_with_database_role", help=read_data_with_database_role.__doc__ ) subparsers.add_parser("list_database_roles", help=list_database_roles.__doc__) - subparsers.add_parser("insert_proto_columns_data_with_dml", help=insert_proto_columns_data_with_dml.__doc__) - subparsers.add_parser("insert_proto_columns_data", help=insert_proto_columns_data.__doc__) - subparsers.add_parser("read_proto_columns_data", 
help=read_proto_columns_data.__doc__) subparsers.add_parser( - "read_proto_columns_data_using_helper_method", help=read_proto_columns_data_using_helper_method.__doc__ + "create_database_with_proto_descriptor", + help=create_database_with_proto_descriptor.__doc__, + ) + subparsers.add_parser( + "insert_proto_columns_data_using_dml", + help=insert_proto_columns_data_using_dml.__doc__, + ) + subparsers.add_parser( + "insert_proto_columns_data", help=insert_proto_columns_data.__doc__ + ) + subparsers.add_parser( + "read_proto_columns_data", help=read_proto_columns_data.__doc__ + ) + subparsers.add_parser( + "read_proto_columns_data_using_helper_method", + help=read_proto_columns_data_using_helper_method.__doc__, + ) + subparsers.add_parser( + "read_proto_columns_data_using_dql", + help=read_proto_columns_data_using_dql.__doc__, ) - subparsers.add_parser("query_proto_columns_data", help=query_proto_columns_data.__doc__) enable_fine_grained_access_parser = subparsers.add_parser( "enable_fine_grained_access", help=enable_fine_grained_access.__doc__ ) @@ -2812,8 +2893,6 @@ def query_proto_columns_data(instance_id, database_id): create_instance(args.instance_id) elif args.command == "create_database": create_database(args.instance_id, args.database_id) - elif args.command == "create_database_with_proto_descriptors": - create_database_with_proto_descriptors(args.instance_id, args.database_id) elif args.command == "get_database_ddl": get_database_ddl(args.instance_id, args.database_id) elif args.command == "insert_data": @@ -2938,13 +3017,15 @@ def query_proto_columns_data(instance_id, database_id): args.database_role, args.title, ) - elif args.command == "insert_proto_columns_data_with_dml": - insert_proto_columns_data_with_dml(args.instance_id, args.database_id) + elif args.command == "create_database_with_proto_descriptor": + create_database_with_proto_descriptor(args.instance_id, args.database_id) + elif args.command == "insert_proto_columns_data_using_dml": + 
insert_proto_columns_data_using_dml(args.instance_id, args.database_id) elif args.command == "insert_proto_columns_data": insert_proto_columns_data(args.instance_id, args.database_id) elif args.command == "read_proto_columns_data": read_proto_columns_data(args.instance_id, args.database_id) elif args.command == "read_proto_columns_data_using_helper_method": read_proto_columns_data_using_helper_method(args.instance_id, args.database_id) - elif args.command == "query_proto_columns_data": - query_proto_columns_data(args.instance_id, args.database_id) + elif args.command == "read_proto_columns_data_using_dql": + read_proto_columns_data_using_dql(args.instance_id, args.database_id) diff --git a/samples/samples/snippets_test.py b/samples/samples/snippets_test.py index af80e0b535..90834d5339 100644 --- a/samples/samples/snippets_test.py +++ b/samples/samples/snippets_test.py @@ -45,13 +45,13 @@ """ CREATE_TABLE_SINGERS_PROTO = """\ -CREATE TABLE SingersProto ( -SingerId INT64 NOT NULL, -FirstName STRING(1024), -LastName STRING(1024), -SingerInfo spanner.examples.music.SingerInfo, -SingerGenre spanner.examples.music.Genre, -SingerInfoArray ARRAY, +CREATE TABLE Singers ( +SingerId INT64 NOT NULL, +FirstName STRING(1024), +LastName STRING(1024), +SingerInfo spanner.examples.music.SingerInfo, +SingerGenre spanner.examples.music.Genre, +SingerInfoArray ARRAY, SingerGenreArray ARRAY, ) PRIMARY KEY (SingerId) """ @@ -119,7 +119,16 @@ def database_ddl(): Sample testcase modules can override as needed. """ - return [CREATE_TABLE_SINGERS, CREATE_TABLE_ALBUMS, CREATE_PROTO_BUNDLE, CREATE_TABLE_SINGERS_PROTO] + return [CREATE_TABLE_SINGERS, CREATE_TABLE_ALBUMS] + + +@pytest.fixture(scope="module") +def database_ddl_for_proto_columns(): + """Sequence of DDL statements used to set up the database for proto columns. + + Sample testcase modules can override as needed. 
+ """ + return [CREATE_PROTO_BUNDLE, CREATE_TABLE_SINGERS_PROTO] @pytest.fixture(scope="module") @@ -184,8 +193,8 @@ def test_create_database_with_encryption_config( assert kms_key_name in out -def test_create_database_with_proto_descriptors(capsys, instance_id, database_id): - snippets.create_database_with_proto_descriptors(instance_id, database_id) +def test_create_database_with_proto_descriptor(capsys, instance_id, database_id): + snippets.create_database_with_proto_descriptor(instance_id, database_id) out, _ = capsys.readouterr() assert database_id in out assert instance_id in out @@ -809,23 +818,37 @@ def test_list_database_roles(capsys, instance_id, sample_database): assert "new_parent" in out +def test_update_database_with_proto_descriptor(capsys, sample_instance, create_database_id): + # We have to create a new database here as proto samples also have Singers table and this will clash. + sample_instance.database(create_database_id).create().result(240) + snippets.update_database_with_proto_descriptor(sample_instance.instance_id, create_database_id) + out, _ = capsys.readouterr() + assert "updated with proto descriptors" in out + database = sample_instance.database(create_database_id) + database.drop() + + @pytest.mark.dependency(name="insert_proto_columns_data_dml") -def test_insert_proto_columns_data_with_dml(capsys, instance_id, sample_database): - snippets.insert_proto_columns_data_with_dml(instance_id, sample_database.database_id) +def test_insert_proto_columns_data_using_dml(capsys, instance_id, sample_database_for_proto_columns): + snippets.insert_proto_columns_data_using_dml( + instance_id, sample_database_for_proto_columns.database_id + ) out, _ = capsys.readouterr() assert "record(s) inserted" in out @pytest.mark.dependency(name="insert_proto_columns_data") -def test_insert_proto_columns_data(capsys, instance_id, sample_database): - snippets.insert_proto_columns_data(instance_id, sample_database.database_id) +def 
test_insert_proto_columns_data(capsys, instance_id, sample_database_for_proto_columns): + snippets.insert_proto_columns_data(instance_id, sample_database_for_proto_columns.database_id) out, _ = capsys.readouterr() assert "Inserted data" in out -@pytest.mark.dependency(depends=["insert_proto_columns_data_dml, insert_proto_columns_data"]) -def test_query_proto_columns_data(capsys, instance_id, sample_database): - snippets.query_proto_columns_data(instance_id, sample_database.database_id) +@pytest.mark.dependency( + depends=["insert_proto_columns_data_dml, insert_proto_columns_data"] +) +def test_read_proto_columns_data_using_dql(capsys, instance_id, sample_database_for_proto_columns): + snippets.read_proto_columns_data_using_dql(instance_id, sample_database_for_proto_columns.database_id) out, _ = capsys.readouterr() assert "SingerId: 1, FirstName: Virginia, LastName: Watson" in out @@ -833,9 +856,11 @@ def test_query_proto_columns_data(capsys, instance_id, sample_database): assert "SingerId: 3, FirstName: Catalina, LastName: Smith" in out -@pytest.mark.dependency(depends=["insert_proto_columns_data_dml, insert_proto_columns_data"]) -def test_read_proto_columns_data(capsys, instance_id, sample_database): - snippets.read_proto_columns_data(instance_id, sample_database.database_id) +@pytest.mark.dependency( + depends=["insert_proto_columns_data_dml, insert_proto_columns_data"] +) +def test_read_proto_columns_data(capsys, instance_id, sample_database_for_proto_columns): + snippets.read_proto_columns_data(instance_id, sample_database_for_proto_columns.database_id) out, _ = capsys.readouterr() assert "SingerId: 1, FirstName: Virginia, LastName: Watson" in out @@ -843,9 +868,15 @@ def test_read_proto_columns_data(capsys, instance_id, sample_database): assert "SingerId: 3, FirstName: Catalina, LastName: Smith" in out -@pytest.mark.dependency(depends=["insert_proto_columns_data_dml, insert_proto_columns_data"]) -def test_read_proto_columns_data_using_helper_method(capsys, 
instance_id, sample_database): - snippets.read_proto_columns_data_using_helper_method(instance_id, sample_database.database_id) +@pytest.mark.dependency( + depends=["insert_proto_columns_data_dml, insert_proto_columns_data"] +) +def test_read_proto_columns_data_using_helper_method( + capsys, instance_id, sample_database_for_proto_columns +): + snippets.read_proto_columns_data_using_helper_method( + instance_id, sample_database_for_proto_columns.database_id + ) out, _ = capsys.readouterr() assert "SingerId: 1, FirstName: Virginia, LastName: Watson" in out diff --git a/setup.py b/setup.py index 86f2203d20..650b452838 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ "grpc-google-iam-v1 >= 0.12.4, <1.0.0dev", "proto-plus >= 1.22.0, <2.0.0dev", "sqlparse >= 0.3.0", - "protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", + "protobuf>=3.20.2,<5.0.0dev,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", ] extras = { "tracing": [ diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt index e061a1eadf..cd64ca21f9 100644 --- a/testing/constraints-3.7.txt +++ b/testing/constraints-3.7.txt @@ -13,4 +13,4 @@ sqlparse==0.3.0 opentelemetry-api==1.1.0 opentelemetry-sdk==1.1.0 opentelemetry-instrumentation==0.20b0 -protobuf==3.19.5 +protobuf==3.20.2 diff --git a/tests/_fixtures.py b/tests/_fixtures.py index 0bd8fe163a..62616c6969 100644 --- a/tests/_fixtures.py +++ b/tests/_fixtures.py @@ -28,6 +28,10 @@ phone_number STRING(1024) ) PRIMARY KEY (contact_id, phone_type), INTERLEAVE IN PARENT contacts ON DELETE CASCADE; +CREATE PROTO BUNDLE ( + spanner.examples.music.SingerInfo, + spanner.examples.music.Genre, + ); CREATE TABLE all_types ( pkey INT64 NOT NULL, int_value INT64, @@ -48,6 +52,10 @@ numeric_array ARRAY, json_value JSON, json_array ARRAY, + proto_message_value spanner.examples.music.SingerInfo, + proto_message_array ARRAY, + proto_enum_value spanner.examples.music.Genre, + proto_enum_array ARRAY, ) 
PRIMARY KEY (pkey); CREATE TABLE counters ( @@ -159,8 +167,22 @@ CREATE INDEX name ON contacts(first_name, last_name); """ +PROTO_COLUMNS_DDL = """\ +CREATE TABLE singers ( + singer_id INT64 NOT NULL, + first_name STRING(1024), + last_name STRING(1024), + singer_info spanner.examples.music.SingerInfo, + singer_genre spanner.examples.music.Genre, ) + PRIMARY KEY (singer_id); +CREATE INDEX SingerByGenre ON singers(singer_genre) STORING (first_name, last_name); +""" + DDL_STATEMENTS = [stmt.strip() for stmt in DDL.split(";") if stmt.strip()] EMULATOR_DDL_STATEMENTS = [ stmt.strip() for stmt in EMULATOR_DDL.split(";") if stmt.strip() ] PG_DDL_STATEMENTS = [stmt.strip() for stmt in PG_DDL.split(";") if stmt.strip()] +PROTO_COLUMNS_DDL_STATEMENTS = [ + stmt.strip() for stmt in PROTO_COLUMNS_DDL.split(";") if stmt.strip() +] diff --git a/tests/system/_helpers.py b/tests/system/_helpers.py index 60926b216e..b62d453512 100644 --- a/tests/system/_helpers.py +++ b/tests/system/_helpers.py @@ -65,6 +65,8 @@ ) ) +PROTO_COLUMNS_DDL_STATEMENTS = _fixtures.PROTO_COLUMNS_DDL_STATEMENTS + retry_true = retry.RetryResult(operator.truth) retry_false = retry.RetryResult(operator.not_) diff --git a/tests/system/_sample_data.py b/tests/system/_sample_data.py index a7f3b80a86..f7f23fc5d2 100644 --- a/tests/system/_sample_data.py +++ b/tests/system/_sample_data.py @@ -18,7 +18,7 @@ from google.api_core import datetime_helpers from google.cloud._helpers import UTC from google.cloud import spanner_v1 - +from samples.samples.testdata import singer_pb2 TABLE = "contacts" COLUMNS = ("contact_id", "first_name", "last_name", "email") @@ -33,6 +33,31 @@ COUNTERS_TABLE = "counters" COUNTERS_COLUMNS = ("name", "value") +SINGERS_PROTO_TABLE = "singers" +SINGERS_PROTO_COLUMNS = ( + "singer_id", + "first_name", + "last_name", + "singer_info", + "singer_genre", +) +SINGER_INFO_1 = singer_pb2.SingerInfo() +SINGER_GENRE_1 = singer_pb2.Genre.ROCK +SINGER_INFO_1.singer_id = 1 +SINGER_INFO_1.birth_date = 
"January" +SINGER_INFO_1.nationality = "Country1" +SINGER_INFO_1.genre = SINGER_GENRE_1 +SINGER_INFO_2 = singer_pb2.SingerInfo() +SINGER_GENRE_2 = singer_pb2.Genre.FOLK +SINGER_INFO_2.singer_id = 2 +SINGER_INFO_2.birth_date = "February" +SINGER_INFO_2.nationality = "Country2" +SINGER_INFO_2.genre = SINGER_GENRE_2 +SINGERS_PROTO_ROW_DATA = ( + (1, "Singer1", "Singer1", SINGER_INFO_1, SINGER_GENRE_1), + (2, "Singer2", "Singer2", SINGER_INFO_2, SINGER_GENRE_2), +) + def _assert_timestamp(value, nano_value): assert isinstance(value, datetime.datetime) diff --git a/tests/system/conftest.py b/tests/system/conftest.py index fdeab14c8f..62b06019f5 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -74,6 +74,17 @@ def database_dialect(): ) +@pytest.fixture(scope="session") +def proto_descriptor_file(): + import os + + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "testdata/descriptors.pb") + file = open(filename, "rb") + yield file.read() + file.close() + + @pytest.fixture(scope="session") def spanner_client(): if _helpers.USE_EMULATOR: @@ -177,7 +188,9 @@ def shared_instance( @pytest.fixture(scope="session") -def shared_database(shared_instance, database_operation_timeout, database_dialect): +def shared_database( + shared_instance, database_operation_timeout, database_dialect, proto_descriptor_file +): database_name = _helpers.unique_id("test_database") pool = spanner_v1.BurstyPool(labels={"testcase": "database_api"}) if database_dialect == DatabaseDialect.POSTGRESQL: @@ -198,6 +211,7 @@ def shared_database(shared_instance, database_operation_timeout, database_dialec ddl_statements=_helpers.DDL_STATEMENTS, pool=pool, database_dialect=database_dialect, + proto_descriptors=proto_descriptor_file, ) operation = database.create() operation.result(database_operation_timeout) # raises on failure / timeout. 
diff --git a/tests/system/test_backup_api.py b/tests/system/test_backup_api.py index dc80653786..6ffc74283e 100644 --- a/tests/system/test_backup_api.py +++ b/tests/system/test_backup_api.py @@ -94,7 +94,9 @@ def database_version_time(shared_database): @pytest.fixture(scope="session") -def second_database(shared_instance, database_operation_timeout, database_dialect): +def second_database( + shared_instance, database_operation_timeout, database_dialect, proto_descriptor_file +): database_name = _helpers.unique_id("test_database2") pool = spanner_v1.BurstyPool(labels={"testcase": "database_api"}) if database_dialect == DatabaseDialect.POSTGRESQL: @@ -115,6 +117,7 @@ def second_database(shared_instance, database_operation_timeout, database_dialec ddl_statements=_helpers.DDL_STATEMENTS, pool=pool, database_dialect=database_dialect, + proto_descriptors=proto_descriptor_file, ) operation = database.create() operation.result(database_operation_timeout) # raises on failure / timeout. diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index 364c159da5..2108667c7e 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import time import uuid @@ -75,7 +74,11 @@ def test_create_database(shared_instance, databases_to_delete, database_dialect) def test_database_binding_of_fixed_size_pool( - not_emulator, shared_instance, databases_to_delete, not_postgres + not_emulator, + shared_instance, + databases_to_delete, + not_postgres, + proto_descriptor_file, ): temp_db_id = _helpers.unique_id("fixed_size_db", separator="_") temp_db = shared_instance.database(temp_db_id) @@ -89,7 +92,9 @@ def test_database_binding_of_fixed_size_pool( "CREATE ROLE parent", "GRANT SELECT ON TABLE contacts TO ROLE parent", ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. pool = FixedSizePool( @@ -102,7 +107,11 @@ def test_database_binding_of_fixed_size_pool( def test_database_binding_of_pinging_pool( - not_emulator, shared_instance, databases_to_delete, not_postgres + not_emulator, + shared_instance, + databases_to_delete, + not_postgres, + proto_descriptor_file, ): temp_db_id = _helpers.unique_id("binding_db", separator="_") temp_db = shared_instance.database(temp_db_id) @@ -116,7 +125,9 @@ def test_database_binding_of_pinging_pool( "CREATE ROLE parent", "GRANT SELECT ON TABLE contacts TO ROLE parent", ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. 
pool = PingingPool( @@ -291,7 +302,7 @@ def test_table_not_found(shared_instance): def test_update_ddl_w_operation_id( - shared_instance, databases_to_delete, database_dialect + shared_instance, databases_to_delete, database_dialect, proto_descriptor_file ): # We used to have: # @pytest.mark.skip( @@ -309,7 +320,11 @@ def test_update_ddl_w_operation_id( # random but shortish always start with letter operation_id = f"a{str(uuid.uuid4())[:8]}" - operation = temp_db.update_ddl(_helpers.DDL_STATEMENTS, operation_id=operation_id) + operation = temp_db.update_ddl( + _helpers.DDL_STATEMENTS, + operation_id=operation_id, + proto_descriptors=proto_descriptor_file, + ) assert operation_id == operation.operation.name.split("/")[-1] @@ -325,6 +340,7 @@ def test_update_ddl_w_pitr_invalid( not_postgres, shared_instance, databases_to_delete, + proto_descriptor_file, ): pool = spanner_v1.BurstyPool(labels={"testcase": "update_database_ddl_pitr"}) temp_db_id = _helpers.unique_id("pitr_upd_ddl_inv", separator="_") @@ -342,7 +358,7 @@ def test_update_ddl_w_pitr_invalid( f" SET OPTIONS (version_retention_period = '{retention_period}')" ] with pytest.raises(exceptions.InvalidArgument): - temp_db.update_ddl(ddl_statements) + temp_db.update_ddl(ddl_statements, proto_descriptors=proto_descriptor_file) def test_update_ddl_w_pitr_success( @@ -350,6 +366,7 @@ def test_update_ddl_w_pitr_success( not_postgres, shared_instance, databases_to_delete, + proto_descriptor_file, ): pool = spanner_v1.BurstyPool(labels={"testcase": "update_database_ddl_pitr"}) temp_db_id = _helpers.unique_id("pitr_upd_ddl_inv", separator="_") @@ -366,7 +383,9 @@ def test_update_ddl_w_pitr_success( f"ALTER DATABASE {temp_db_id}" f" SET OPTIONS (version_retention_period = '{retention_period}')" ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. 
temp_db.reload() @@ -379,6 +398,7 @@ def test_update_ddl_w_default_leader_success( not_postgres, multiregion_instance, databases_to_delete, + proto_descriptor_file, ): pool = spanner_v1.BurstyPool( labels={"testcase": "update_database_ddl_default_leader"}, @@ -398,7 +418,9 @@ def test_update_ddl_w_default_leader_success( f"ALTER DATABASE {temp_db_id}" f" SET OPTIONS (default_leader = '{default_leader}')" ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. temp_db.reload() @@ -411,6 +433,7 @@ def test_create_role_grant_access_success( shared_instance, databases_to_delete, not_postgres, + proto_descriptor_file, ): creator_role_parent = _helpers.unique_id("role_parent", separator="_") creator_role_orphan = _helpers.unique_id("role_orphan", separator="_") @@ -428,7 +451,9 @@ def test_create_role_grant_access_success( f"CREATE ROLE {creator_role_orphan}", f"GRANT SELECT ON TABLE contacts TO ROLE {creator_role_parent}", ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. # Perform select with orphan role on table contacts. 
@@ -460,6 +485,7 @@ def test_list_database_role_success( shared_instance, databases_to_delete, not_postgres, + proto_descriptor_file, ): creator_role_parent = _helpers.unique_id("role_parent", separator="_") creator_role_orphan = _helpers.unique_id("role_orphan", separator="_") @@ -476,7 +502,9 @@ def test_list_database_role_success( f"CREATE ROLE {creator_role_parent}", f"CREATE ROLE {creator_role_orphan}", ] - operation = temp_db.update_ddl(ddl_statements) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. # List database roles. @@ -562,3 +590,30 @@ def _unit_of_work(transaction, name): rows = list(after.read(sd.COUNTERS_TABLE, sd.COUNTERS_COLUMNS, sd.ALL)) assert len(rows) == 2 + + +def test_create_table_with_proto_columns( + not_emulator, + not_postgres, + shared_instance, + databases_to_delete, + proto_descriptor_file, +): + proto_cols_db_id = _helpers.unique_id("proto-columns") + extra_ddl = [ + "CREATE PROTO BUNDLE (spanner.examples.music.SingerInfo, spanner.examples.music.Genre,)" + ] + + proto_cols_database = shared_instance.database( + proto_cols_db_id, + ddl_statements=extra_ddl + _helpers.PROTO_COLUMNS_DDL_STATEMENTS, + proto_descriptors=proto_descriptor_file, + ) + operation = proto_cols_database.create() + operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. + + databases_to_delete.append(proto_cols_database) + + proto_cols_database.reload() + assert proto_cols_database.proto_descriptors is not None + assert any("PROTO BUNDLE" in stmt for stmt in proto_cols_database.ddl_statements) diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 6b7afbe525..8b00073567 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - +import base64 import collections import datetime import decimal @@ -29,6 +29,7 @@ from google.cloud.spanner_admin_database_v1 import DatabaseDialect from google.cloud._helpers import UTC from google.cloud.spanner_v1.data_types import JsonObject +from samples.samples.testdata import singer_pb2 from tests import _helpers as ot_helpers from . import _helpers from . import _sample_data @@ -57,6 +58,8 @@ JSON_2 = JsonObject( {"sample_object": {"name": "Anamika", "id": 2635}}, ) +SINGER_INFO = _sample_data.SINGER_INFO_1 +SINGER_GENRE = _sample_data.SINGER_GENRE_1 COUNTERS_TABLE = "counters" COUNTERS_COLUMNS = ("name", "value") @@ -81,6 +84,10 @@ "numeric_array", "json_value", "json_array", + "proto_message_value", + "proto_message_array", + "proto_enum_value", + "proto_enum_array", ) EMULATOR_ALL_TYPES_COLUMNS = LIVE_ALL_TYPES_COLUMNS[:-4] @@ -120,6 +127,8 @@ AllTypesRowData(pkey=109, numeric_value=NUMERIC_1), AllTypesRowData(pkey=110, json_value=JSON_1), AllTypesRowData(pkey=111, json_value=JsonObject([JSON_1, JSON_2])), + AllTypesRowData(pkey=112, proto_message_value=SINGER_INFO), + AllTypesRowData(pkey=113, proto_enum_value=SINGER_GENRE), # empty array values AllTypesRowData(pkey=201, int_array=[]), AllTypesRowData(pkey=202, bool_array=[]), @@ -130,6 +139,8 @@ AllTypesRowData(pkey=207, timestamp_array=[]), AllTypesRowData(pkey=208, numeric_array=[]), AllTypesRowData(pkey=209, json_array=[]), + AllTypesRowData(pkey=210, proto_message_array=[]), + AllTypesRowData(pkey=211, proto_enum_array=[]), # non-empty array values, including nulls AllTypesRowData(pkey=301, int_array=[123, 456, None]), AllTypesRowData(pkey=302, bool_array=[True, False, None]), @@ -142,6 +153,8 @@ AllTypesRowData(pkey=307, timestamp_array=[SOME_TIME, NANO_TIME, None]), AllTypesRowData(pkey=308, numeric_array=[NUMERIC_1, NUMERIC_2, None]), AllTypesRowData(pkey=309, json_array=[JSON_1, JSON_2, 
None]), + AllTypesRowData(pkey=310, proto_message_array=[SINGER_INFO, None]), + AllTypesRowData(pkey=311, proto_enum_array=[SINGER_GENRE, None]), ) EMULATOR_ALL_TYPES_ROWDATA = ( # all nulls @@ -221,9 +234,16 @@ ALL_TYPES_COLUMNS = LIVE_ALL_TYPES_COLUMNS ALL_TYPES_ROWDATA = LIVE_ALL_TYPES_ROWDATA +COLUMN_INFO = { + "proto_message_value": singer_pb2.SingerInfo(), + "proto_message_array": singer_pb2.SingerInfo(), +} + @pytest.fixture(scope="session") -def sessions_database(shared_instance, database_operation_timeout, database_dialect): +def sessions_database( + shared_instance, database_operation_timeout, database_dialect, proto_descriptor_file +): database_name = _helpers.unique_id("test_sessions", separator="_") pool = spanner_v1.BurstyPool(labels={"testcase": "session_api"}) @@ -245,6 +265,7 @@ def sessions_database(shared_instance, database_operation_timeout, database_dial database_name, ddl_statements=_helpers.DDL_STATEMENTS, pool=pool, + proto_descriptors=proto_descriptor_file, ) operation = sessions_database.create() @@ -459,7 +480,11 @@ def test_batch_insert_then_read_all_datatypes(sessions_database): batch.insert(ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, ALL_TYPES_ROWDATA) with sessions_database.snapshot(read_timestamp=batch.committed) as snapshot: - rows = list(snapshot.read(ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, sd.ALL)) + rows = list( + snapshot.read( + ALL_TYPES_TABLE, ALL_TYPES_COLUMNS, sd.ALL, column_info=COLUMN_INFO + ) + ) sd._check_rows_data(rows, expected=ALL_TYPES_ROWDATA) @@ -1315,6 +1340,21 @@ def _unit_of_work(transaction): return committed +def _set_up_proto_table(database): + + sd = _sample_data + + def _unit_of_work(transaction): + transaction.delete(sd.SINGERS_PROTO_TABLE, sd.ALL) + transaction.insert( + sd.SINGERS_PROTO_TABLE, sd.SINGERS_PROTO_COLUMNS, sd.SINGERS_PROTO_ROW_DATA + ) + + committed = database.run_in_transaction(_unit_of_work) + + return committed + + def test_read_with_single_keys_index(sessions_database): # [START 
spanner_test_single_key_index_read] sd = _sample_data @@ -1464,7 +1504,11 @@ def test_multiuse_snapshot_read_isolation_exact_staleness(sessions_database): def test_read_w_index( - shared_instance, database_operation_timeout, databases_to_delete, database_dialect + shared_instance, + database_operation_timeout, + databases_to_delete, + database_dialect, + proto_descriptor_file, ): # Indexed reads cannot return non-indexed columns sd = _sample_data @@ -1492,9 +1536,12 @@ def test_read_w_index( else: temp_db = shared_instance.database( _helpers.unique_id("test_read", separator="_"), - ddl_statements=_helpers.DDL_STATEMENTS + extra_ddl, + ddl_statements=_helpers.DDL_STATEMENTS + + extra_ddl + + _helpers.PROTO_COLUMNS_DDL_STATEMENTS, pool=pool, database_dialect=database_dialect, + proto_descriptors=proto_descriptor_file, ) operation = temp_db.create() operation.result(database_operation_timeout) # raises on failure / timeout. @@ -1510,6 +1557,28 @@ def test_read_w_index( expected = list(reversed([(row[0], row[2]) for row in _row_data(row_count)])) sd._check_rows_data(rows, expected) + # Test indexes on proto column types + if database_dialect == DatabaseDialect.GOOGLE_STANDARD_SQL: + # Indexed reads cannot return non-indexed columns + my_columns = ( + sd.SINGERS_PROTO_COLUMNS[0], + sd.SINGERS_PROTO_COLUMNS[1], + sd.SINGERS_PROTO_COLUMNS[4], + ) + committed = _set_up_proto_table(temp_db) + with temp_db.snapshot(read_timestamp=committed) as snapshot: + rows = list( + snapshot.read( + sd.SINGERS_PROTO_TABLE, + my_columns, + spanner_v1.KeySet(keys=[[singer_pb2.Genre.ROCK]]), + index="SingerByGenre", + ) + ) + row = sd.SINGERS_PROTO_ROW_DATA[0] + expected = list([(row[0], row[1], row[4])]) + sd._check_rows_data(rows, expected) + def test_read_w_single_key(sessions_database): # [START spanner_test_single_key_read] @@ -1922,12 +1991,17 @@ def _check_sql_results( expected, order=True, recurse_into_lists=True, + column_info=None, ): if order and "ORDER" not in sql: sql += " 
ORDER BY pkey" with database.snapshot() as snapshot: - rows = list(snapshot.execute_sql(sql, params=params, param_types=param_types)) + rows = list( + snapshot.execute_sql( + sql, params=params, param_types=param_types, column_info=column_info + ) + ) _sample_data._check_rows_data( rows, expected=expected, recurse_into_lists=recurse_into_lists @@ -2023,32 +2097,39 @@ def _bind_test_helper( array_value, expected_array_value=None, recurse_into_lists=True, + column_info=None, + expected_single_value=None, ): database.snapshot(multi_use=True) key = "p1" if database_dialect == DatabaseDialect.POSTGRESQL else "v" placeholder = "$1" if database_dialect == DatabaseDialect.POSTGRESQL else f"@{key}" + if expected_single_value is None: + expected_single_value = single_value + # Bind a non-null _check_sql_results( database, - sql=f"SELECT {placeholder}", + sql=f"SELECT {placeholder} as column", params={key: single_value}, param_types={key: param_type}, - expected=[(single_value,)], + expected=[(expected_single_value,)], order=False, recurse_into_lists=recurse_into_lists, + column_info=column_info, ) # Bind a null _check_sql_results( database, - sql=f"SELECT {placeholder}", + sql=f"SELECT {placeholder} as column", params={key: None}, param_types={key: param_type}, expected=[(None,)], order=False, recurse_into_lists=recurse_into_lists, + column_info=column_info, ) # Bind an array of @@ -2062,34 +2143,37 @@ def _bind_test_helper( _check_sql_results( database, - sql=f"SELECT {placeholder}", + sql=f"SELECT {placeholder} as column", params={key: array_value}, param_types={key: array_type}, expected=[(expected_array_value,)], order=False, recurse_into_lists=recurse_into_lists, + column_info=column_info, ) # Bind an empty array of _check_sql_results( database, - sql=f"SELECT {placeholder}", + sql=f"SELECT {placeholder} as column", params={key: []}, param_types={key: array_type}, expected=[([],)], order=False, recurse_into_lists=recurse_into_lists, + column_info=column_info, ) # Bind a 
null array of _check_sql_results( database, - sql=f"SELECT {placeholder}", + sql=f"SELECT {placeholder} as column", params={key: None}, param_types={key: array_type}, expected=[(None,)], order=False, recurse_into_lists=recurse_into_lists, + column_info=column_info, ) @@ -2457,6 +2541,80 @@ def test_execute_sql_w_query_param_struct(sessions_database, not_postgres): ) +def test_execute_sql_w_proto_message_bindings( + not_emulator, not_postgres, sessions_database, database_dialect +): + singer_info = _sample_data.SINGER_INFO_1 + singer_info_bytes = base64.b64encode(singer_info.SerializeToString()) + + _bind_test_helper( + sessions_database, + database_dialect, + spanner_v1.param_types.ProtoMessage(singer_info), + singer_info, + [singer_info, None], + column_info={"column": singer_pb2.SingerInfo()}, + ) + + # Tests compatibility between proto message and bytes column types + _bind_test_helper( + sessions_database, + database_dialect, + spanner_v1.param_types.ProtoMessage(singer_info), + singer_info_bytes, + [singer_info_bytes, None], + expected_single_value=singer_info, + expected_array_value=[singer_info, None], + column_info={"column": singer_pb2.SingerInfo()}, + ) + + # Tests compatibility between proto message and bytes column types + _bind_test_helper( + sessions_database, + database_dialect, + spanner_v1.param_types.BYTES, + singer_info, + [singer_info, None], + expected_single_value=singer_info_bytes, + expected_array_value=[singer_info_bytes, None], + ) + + +def test_execute_sql_w_proto_enum_bindings( + not_emulator, not_postgres, sessions_database, database_dialect +): + singer_genre = _sample_data.SINGER_GENRE_1 + + _bind_test_helper( + sessions_database, + database_dialect, + spanner_v1.param_types.ProtoEnum(singer_pb2.Genre), + singer_genre, + [singer_genre, None], + ) + + # Tests compatibility between proto enum and int64 column types + _bind_test_helper( + sessions_database, + database_dialect, + spanner_v1.param_types.ProtoEnum(singer_pb2.Genre), + 3, + 
[3, None], + expected_single_value="ROCK", + expected_array_value=["ROCK", None], + column_info={"column": singer_pb2.Genre}, + ) + + # Tests compatibility between proto enum and int64 column types + _bind_test_helper( + sessions_database, + database_dialect, + spanner_v1.param_types.INT64, + singer_genre, + [singer_genre, None], + ) + + def test_execute_sql_returning_transfinite_floats(sessions_database, not_postgres): with sessions_database.snapshot(multi_use=True) as snapshot: diff --git a/tests/system/testdata/descriptors.pb b/tests/system/testdata/descriptors.pb new file mode 100644 index 0000000000000000000000000000000000000000..3ebb79420b3ffd2ca3b3b57433a4a10bfa22b675 GIT binary patch literal 251 zcmd=3!N|o^oSB!NTBKJ{lwXoBBvxFIn3o6SrdA~87UZNB>*bafXC^DnXXN4v1}pT; zOUoCM=Hi5Ci_c7vU{qk#U=HGd2zaIl$#QWeWfqlW#HS>dq)IRWWjTX5!6Gg|0U-r0 z?!3g3%>2B>oXnC+31+Z7vXGE57i)TIUQwz93s8>FNLCNKqx9TCih>|&we+}H!F(Zh lF6IFL009Oe4lWii$EYX)Mi9%*-^W{k3B(HWclH)w1^^+RM@0Yt literal 0 HcmV?d00001 diff --git a/tests/unit/test__helpers.py b/tests/unit/test__helpers.py index 21434da191..b695f42564 100644 --- a/tests/unit/test__helpers.py +++ b/tests/unit/test__helpers.py @@ -313,6 +313,25 @@ def test_w_json_None(self): value_pb = self._callFUT(value) self.assertTrue(value_pb.HasField("null_value")) + def test_w_proto_message(self): + from google.protobuf.struct_pb2 import Value + import base64 + from samples.samples.testdata import singer_pb2 + + singer_info = singer_pb2.SingerInfo() + expected = Value(string_value=base64.b64encode(singer_info.SerializeToString())) + value_pb = self._callFUT(singer_info) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb, expected) + + def test_w_proto_enum(self): + from google.protobuf.struct_pb2 import Value + from samples.samples.testdata import singer_pb2 + + value_pb = self._callFUT(singer_pb2.Genre.ROCK) + self.assertIsInstance(value_pb, Value) + self.assertEqual(value_pb.string_value, "3") + class 
Test_make_list_value_pb(unittest.TestCase): def _callFUT(self, *args, **kw): @@ -394,9 +413,10 @@ def test_w_null(self): from google.cloud.spanner_v1 import TypeCode field_type = Type(code=TypeCode.STRING) + field_name = "null_column" value_pb = Value(null_value=NULL_VALUE) - self.assertEqual(self._callFUT(value_pb, field_type), None) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), None) def test_w_string(self): from google.protobuf.struct_pb2 import Value @@ -405,9 +425,10 @@ def test_w_string(self): VALUE = "Value" field_type = Type(code=TypeCode.STRING) + field_name = "string_column" value_pb = Value(string_value=VALUE) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_bytes(self): from google.protobuf.struct_pb2 import Value @@ -416,9 +437,10 @@ def test_w_bytes(self): VALUE = b"Value" field_type = Type(code=TypeCode.BYTES) + field_name = "bytes_column" value_pb = Value(string_value=VALUE) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_bool(self): from google.protobuf.struct_pb2 import Value @@ -427,9 +449,10 @@ def test_w_bool(self): VALUE = True field_type = Type(code=TypeCode.BOOL) + field_name = "bool_column" value_pb = Value(bool_value=VALUE) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_int(self): from google.protobuf.struct_pb2 import Value @@ -438,9 +461,10 @@ def test_w_int(self): VALUE = 12345 field_type = Type(code=TypeCode.INT64) + field_name = "int_column" value_pb = Value(string_value=str(VALUE)) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_float(self): from google.protobuf.struct_pb2 import Value @@ -449,9 +473,10 @@ def 
test_w_float(self): VALUE = 3.14159 field_type = Type(code=TypeCode.FLOAT64) + field_name = "float_column" value_pb = Value(number_value=VALUE) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_float_str(self): from google.protobuf.struct_pb2 import Value @@ -460,10 +485,13 @@ def test_w_float_str(self): VALUE = "3.14159" field_type = Type(code=TypeCode.FLOAT64) + field_name = "float_str_column" value_pb = Value(string_value=VALUE) expected_value = 3.14159 - self.assertEqual(self._callFUT(value_pb, field_type), expected_value) + self.assertEqual( + self._callFUT(value_pb, field_type, field_name), expected_value + ) def test_w_date(self): import datetime @@ -473,9 +501,10 @@ def test_w_date(self): VALUE = datetime.date.today() field_type = Type(code=TypeCode.DATE) + field_name = "date_column" value_pb = Value(string_value=VALUE.isoformat()) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_timestamp_wo_nanos(self): import datetime @@ -488,9 +517,10 @@ def test_w_timestamp_wo_nanos(self): 2016, 12, 20, 21, 13, 47, microsecond=123456, tzinfo=datetime.timezone.utc ) field_type = Type(code=TypeCode.TIMESTAMP) + field_name = "nanos_column" value_pb = Value(string_value=datetime_helpers.to_rfc3339(value)) - parsed = self._callFUT(value_pb, field_type) + parsed = self._callFUT(value_pb, field_type, field_name) self.assertIsInstance(parsed, datetime_helpers.DatetimeWithNanoseconds) self.assertEqual(parsed, value) @@ -505,9 +535,10 @@ def test_w_timestamp_w_nanos(self): 2016, 12, 20, 21, 13, 47, nanosecond=123456789, tzinfo=datetime.timezone.utc ) field_type = Type(code=TypeCode.TIMESTAMP) + field_name = "timestamp_column" value_pb = Value(string_value=datetime_helpers.to_rfc3339(value)) - parsed = self._callFUT(value_pb, field_type) + parsed = self._callFUT(value_pb, field_type, 
field_name) self.assertIsInstance(parsed, datetime_helpers.DatetimeWithNanoseconds) self.assertEqual(parsed, value) @@ -519,9 +550,10 @@ def test_w_array_empty(self): field_type = Type( code=TypeCode.ARRAY, array_element_type=Type(code=TypeCode.INT64) ) + field_name = "array_empty_column" value_pb = Value(list_value=ListValue(values=[])) - self.assertEqual(self._callFUT(value_pb, field_type), []) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), []) def test_w_array_non_empty(self): from google.protobuf.struct_pb2 import Value, ListValue @@ -531,13 +563,14 @@ def test_w_array_non_empty(self): field_type = Type( code=TypeCode.ARRAY, array_element_type=Type(code=TypeCode.INT64) ) + field_name = "array_non_empty_column" VALUES = [32, 19, 5] values_pb = ListValue( values=[Value(string_value=str(value)) for value in VALUES] ) value_pb = Value(list_value=values_pb) - self.assertEqual(self._callFUT(value_pb, field_type), VALUES) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUES) def test_w_struct(self): from google.protobuf.struct_pb2 import Value @@ -554,9 +587,10 @@ def test_w_struct(self): ] ) field_type = Type(code=TypeCode.STRUCT, struct_type=struct_type_pb) + field_name = "struct_column" value_pb = Value(list_value=_make_list_value_pb(VALUES)) - self.assertEqual(self._callFUT(value_pb, field_type), VALUES) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUES) def test_w_numeric(self): import decimal @@ -566,9 +600,10 @@ def test_w_numeric(self): VALUE = decimal.Decimal("99999999999999999999999999999.999999999") field_type = Type(code=TypeCode.NUMERIC) + field_name = "numeric_column" value_pb = Value(string_value=str(VALUE)) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_json(self): import json @@ -580,9 +615,10 @@ def test_w_json(self): str_repr = json.dumps(VALUE, sort_keys=True, separators=(",", 
":")) field_type = Type(code=TypeCode.JSON) + field_name = "json_column" value_pb = Value(string_value=str_repr) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) VALUE = None str_repr = json.dumps(VALUE, sort_keys=True, separators=(",", ":")) @@ -590,7 +626,7 @@ def test_w_json(self): field_type = Type(code=TypeCode.JSON) value_pb = Value(string_value=str_repr) - self.assertEqual(self._callFUT(value_pb, field_type), {}) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), {}) def test_w_unknown_type(self): from google.protobuf.struct_pb2 import Value @@ -598,10 +634,44 @@ def test_w_unknown_type(self): from google.cloud.spanner_v1 import TypeCode field_type = Type(code=TypeCode.TYPE_CODE_UNSPECIFIED) + field_name = "unknown_column" value_pb = Value(string_value="Borked") with self.assertRaises(ValueError): - self._callFUT(value_pb, field_type) + self._callFUT(value_pb, field_type, field_name) + + def test_w_proto_message(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.spanner_v1 import Type + from google.cloud.spanner_v1 import TypeCode + import base64 + from samples.samples.testdata import singer_pb2 + + VALUE = singer_pb2.SingerInfo() + field_type = Type(code=TypeCode.PROTO) + field_name = "proto_message_column" + value_pb = Value(string_value=base64.b64encode(VALUE.SerializeToString())) + column_info = {"proto_message_column": singer_pb2.SingerInfo()} + + self.assertEqual( + self._callFUT(value_pb, field_type, field_name, column_info), VALUE + ) + + def test_w_proto_enum(self): + from google.protobuf.struct_pb2 import Value + from google.cloud.spanner_v1 import Type + from google.cloud.spanner_v1 import TypeCode + from samples.samples.testdata import singer_pb2 + + VALUE = "ROCK" + field_type = Type(code=TypeCode.ENUM) + field_name = "proto_enum_column" + value_pb = Value(string_value=str(singer_pb2.Genre.ROCK)) + column_info = 
{"proto_enum_column": singer_pb2.Genre} + + self.assertEqual( + self._callFUT(value_pb, field_type, field_name, column_info), VALUE + ) class Test_parse_list_value_pbs(unittest.TestCase): diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index bff89320c7..dbff6c5107 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -186,6 +186,14 @@ def test_ctor_w_encryption_config(self): self.assertIs(database._instance, instance) self.assertEqual(database._encryption_config, encryption_config) + def test_ctor_w_proto_descriptors(self): + + instance = _Instance(self.INSTANCE_NAME) + database = self._make_one(self.DATABASE_ID, instance, proto_descriptors=b"") + self.assertEqual(database.database_id, self.DATABASE_ID) + self.assertIs(database._instance, instance) + self.assertEqual(database._proto_descriptors, b"") + def test_from_pb_bad_database_name(self): from google.cloud.spanner_admin_database_v1 import Database @@ -351,6 +359,15 @@ def test_default_leader(self): default_leader = database._default_leader = "us-east4" self.assertEqual(database.default_leader, default_leader) + def test_proto_descriptors(self): + + instance = _Instance(self.INSTANCE_NAME) + pool = _Pool() + database = self._make_one( + self.DATABASE_ID, instance, pool=pool, proto_descriptors=b"" + ) + self.assertEqual(database.proto_descriptors, b"") + def test_spanner_api_property_w_scopeless_creds(self): client = _Client() @@ -622,6 +639,41 @@ def test_create_success_w_encryption_config_dict(self): metadata=[("google-cloud-resource-prefix", database.name)], ) + def test_create_success_w_proto_descriptors(self): + from tests._fixtures import DDL_STATEMENTS + from google.cloud.spanner_admin_database_v1 import CreateDatabaseRequest + + op_future = object() + client = _Client() + api = client.database_admin_api = self._make_database_admin_api() + api.create_database.return_value = op_future + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + 
proto_descriptors = b"" + database = self._make_one( + self.DATABASE_ID, + instance, + ddl_statements=DDL_STATEMENTS, + pool=pool, + proto_descriptors=proto_descriptors, + ) + + future = database.create() + + self.assertIs(future, op_future) + + expected_request = CreateDatabaseRequest( + parent=self.INSTANCE_NAME, + create_statement="CREATE DATABASE {}".format(self.DATABASE_ID), + extra_statements=DDL_STATEMENTS, + proto_descriptors=proto_descriptors, + ) + + api.create_database.assert_called_once_with( + request=expected_request, + metadata=[("google-cloud-resource-prefix", database.name)], + ) + def test_exists_grpc_error(self): from google.api_core.exceptions import Unknown @@ -877,6 +929,34 @@ def test_update_ddl_w_operation_id(self): metadata=[("google-cloud-resource-prefix", database.name)], ) + def test_update_ddl_w_proto_descriptors(self): + from tests._fixtures import DDL_STATEMENTS + from google.cloud.spanner_admin_database_v1 import UpdateDatabaseDdlRequest + + op_future = object() + client = _Client() + api = client.database_admin_api = self._make_database_admin_api() + api.update_database_ddl.return_value = op_future + instance = _Instance(self.INSTANCE_NAME, client=client) + pool = _Pool() + database = self._make_one(self.DATABASE_ID, instance, pool=pool) + + future = database.update_ddl(DDL_STATEMENTS, proto_descriptors=b"") + + self.assertIs(future, op_future) + + expected_request = UpdateDatabaseDdlRequest( + database=self.DATABASE_NAME, + statements=DDL_STATEMENTS, + operation_id="", + proto_descriptors=b"", + ) + + api.update_database_ddl.assert_called_once_with( + request=expected_request, + metadata=[("google-cloud-resource-prefix", database.name)], + ) + def test_drop_grpc_error(self): from google.api_core.exceptions import Unknown diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index e0a0f663cf..e45b3f051c 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -557,6 +557,7 @@ def 
test_database_factory_explicit(self): pool = _Pool() logger = mock.create_autospec(Logger, instance=True) encryption_config = {"kms_key_name": "kms_key_name"} + proto_descriptors = b"" database = instance.database( DATABASE_ID, @@ -565,6 +566,7 @@ def test_database_factory_explicit(self): logger=logger, encryption_config=encryption_config, database_role=DATABASE_ROLE, + proto_descriptors=proto_descriptors, ) self.assertIsInstance(database, Database) @@ -576,6 +578,7 @@ def test_database_factory_explicit(self): self.assertIs(pool._bound, database) self.assertIs(database._encryption_config, encryption_config) self.assertIs(database.database_role, DATABASE_ROLE) + self.assertIs(database._proto_descriptors, proto_descriptors) def test_list_databases(self): from google.cloud.spanner_admin_database_v1 import Database as DatabasePB diff --git a/tests/unit/test_param_types.py b/tests/unit/test_param_types.py index 02f41c1f25..fad171c918 100644 --- a/tests/unit/test_param_types.py +++ b/tests/unit/test_param_types.py @@ -71,3 +71,37 @@ def test_it(self): found = param_types.PG_JSONB self.assertEqual(found, expected) + + +class Test_ProtoMessageParamType(unittest.TestCase): + def test_it(self): + from google.cloud.spanner_v1 import Type + from google.cloud.spanner_v1 import TypeCode + from google.cloud.spanner_v1 import param_types + from samples.samples.testdata import singer_pb2 + + singer_info = singer_pb2.SingerInfo() + expected = Type( + code=TypeCode.PROTO, proto_type_fqn=singer_info.DESCRIPTOR.full_name + ) + + found = param_types.ProtoMessage(singer_info) + + self.assertEqual(found, expected) + + +class Test_ProtoEnumParamType(unittest.TestCase): + def test_it(self): + from google.cloud.spanner_v1 import Type + from google.cloud.spanner_v1 import TypeCode + from google.cloud.spanner_v1 import param_types + from samples.samples.testdata import singer_pb2 + + singer_genre = singer_pb2.Genre + expected = Type( + code=TypeCode.ENUM, 
proto_type_fqn=singer_genre.DESCRIPTOR.full_name + ) + + found = param_types.ProtoEnum(singer_genre) + + self.assertEqual(found, expected) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index edad4ce777..ce9b205264 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -614,7 +614,12 @@ def test_read(self): self.assertIs(found, snapshot().read.return_value) snapshot().read.assert_called_once_with( - TABLE_NAME, COLUMNS, KEYSET, INDEX, LIMIT + TABLE_NAME, + COLUMNS, + KEYSET, + INDEX, + LIMIT, + column_info=None, ) def test_execute_sql_not_created(self): @@ -645,6 +650,7 @@ def test_execute_sql_defaults(self): request_options=None, timeout=google.api_core.gapic_v1.method.DEFAULT, retry=google.api_core.gapic_v1.method.DEFAULT, + column_info=None, ) def test_execute_sql_non_default_retry(self): @@ -675,6 +681,7 @@ def test_execute_sql_non_default_retry(self): request_options=None, timeout=None, retry=None, + column_info=None, ) def test_execute_sql_explicit(self): @@ -703,6 +710,7 @@ def test_execute_sql_explicit(self): request_options=None, timeout=google.api_core.gapic_v1.method.DEFAULT, retry=google.api_core.gapic_v1.method.DEFAULT, + column_info=None, ) def test_batch_not_created(self): From 0eddbcf4a5cee84e49a940ec56c8fd1eb122ccf9 Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Mon, 29 Jan 2024 08:59:50 +0000 Subject: [PATCH 03/20] feat: rever autogenerated code --- .../types/spanner_database_admin.py | 195 +- google/cloud/spanner_v1/types/type.py | 30 +- ...ixup_spanner_admin_database_v1_keywords.py | 3 +- .../test_database_admin.py | 7265 ++++++++++++++++- 4 files changed, 7263 insertions(+), 230 deletions(-) diff --git a/google/cloud/spanner_admin_database_v1/types/spanner_database_admin.py b/google/cloud/spanner_admin_database_v1/types/spanner_database_admin.py index 163e49416e..b124e628d8 100644 --- a/google/cloud/spanner_admin_database_v1/types/spanner_database_admin.py +++ 
b/google/cloud/spanner_admin_database_v1/types/spanner_database_admin.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations + from typing import MutableMapping, MutableSequence import proto # type: ignore @@ -20,6 +22,7 @@ from google.cloud.spanner_admin_database_v1.types import backup as gsad_backup from google.cloud.spanner_admin_database_v1.types import common from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore from google.protobuf import timestamp_pb2 # type: ignore @@ -34,7 +37,10 @@ "CreateDatabaseRequest", "CreateDatabaseMetadata", "GetDatabaseRequest", + "UpdateDatabaseRequest", + "UpdateDatabaseMetadata", "UpdateDatabaseDdlRequest", + "DdlStatementActionInfo", "UpdateDatabaseDdlMetadata", "DropDatabaseRequest", "GetDatabaseDdlRequest", @@ -125,6 +131,7 @@ class Database(proto.Message): the encryption information for the database, such as encryption state and the Cloud KMS key versions that are in use. + For databases that are using Google default or other types of encryption, this field is empty. @@ -158,6 +165,13 @@ class Database(proto.Message): database_dialect (google.cloud.spanner_admin_database_v1.types.DatabaseDialect): Output only. The dialect of the Cloud Spanner Database. + enable_drop_protection (bool): + Whether drop protection is enabled for this + database. Defaults to false, if not set. + reconciling (bool): + Output only. If true, the database is being + updated. If false, there are no ongoing update + operations for the database. 
""" class State(proto.Enum): @@ -236,6 +250,14 @@ class State(proto.Enum): number=10, enum=common.DatabaseDialect, ) + enable_drop_protection: bool = proto.Field( + proto.BOOL, + number=11, + ) + reconciling: bool = proto.Field( + proto.BOOL, + number=12, + ) class ListDatabasesRequest(proto.Message): @@ -321,8 +343,10 @@ class CreateDatabaseRequest(proto.Message): inside the newly created database. Statements can create tables, indexes, etc. These statements execute atomically with the creation - of the database: if there is an error in any - statement, the database is not created. + of the database: + + if there is an error in any statement, the + database is not created. encryption_config (google.cloud.spanner_admin_database_v1.types.EncryptionConfig): Optional. The encryption configuration for the database. If this field is not specified, @@ -332,10 +356,25 @@ class CreateDatabaseRequest(proto.Message): Optional. The dialect of the Cloud Spanner Database. proto_descriptors (bytes): - Proto descriptors used by CREATE/ALTER PROTO BUNDLE - statements in 'extra_statements' above. Contains a + Optional. Proto descriptors used by CREATE/ALTER PROTO + BUNDLE statements in 'extra_statements' above. Contains a protobuf-serialized `google.protobuf.FileDescriptorSet `__. + To generate it, + `install `__ and + run ``protoc`` with --include_imports and + --descriptor_set_out. For example, to generate for + moon/shot/app.proto, run + + :: + + $protoc --proto_path=/app_path --proto_path=/lib_path \ + --include_imports \ + --descriptor_set_out=descriptors.data \ + moon/shot/app.proto + + For more details, see protobuffer `self + description `__. """ parent: str = proto.Field( @@ -398,6 +437,68 @@ class GetDatabaseRequest(proto.Message): ) +class UpdateDatabaseRequest(proto.Message): + r"""The request for + [UpdateDatabase][google.spanner.admin.database.v1.DatabaseAdmin.UpdateDatabase]. 
+ + Attributes: + database (google.cloud.spanner_admin_database_v1.types.Database): + Required. The database to update. The ``name`` field of the + database is of the form + ``projects//instances//databases/``. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The list of fields to update. Currently, only + ``enable_drop_protection`` field can be updated. + """ + + database: "Database" = proto.Field( + proto.MESSAGE, + number=1, + message="Database", + ) + update_mask: field_mask_pb2.FieldMask = proto.Field( + proto.MESSAGE, + number=2, + message=field_mask_pb2.FieldMask, + ) + + +class UpdateDatabaseMetadata(proto.Message): + r"""Metadata type for the operation returned by + [UpdateDatabase][google.spanner.admin.database.v1.DatabaseAdmin.UpdateDatabase]. + + Attributes: + request (google.cloud.spanner_admin_database_v1.types.UpdateDatabaseRequest): + The request for + [UpdateDatabase][google.spanner.admin.database.v1.DatabaseAdmin.UpdateDatabase]. + progress (google.cloud.spanner_admin_database_v1.types.OperationProgress): + The progress of the + [UpdateDatabase][google.spanner.admin.database.v1.DatabaseAdmin.UpdateDatabase] + operation. + cancel_time (google.protobuf.timestamp_pb2.Timestamp): + The time at which this operation was + cancelled. If set, this operation is in the + process of undoing itself (which is + best-effort). 
+ """ + + request: "UpdateDatabaseRequest" = proto.Field( + proto.MESSAGE, + number=1, + message="UpdateDatabaseRequest", + ) + progress: common.OperationProgress = proto.Field( + proto.MESSAGE, + number=2, + message=common.OperationProgress, + ) + cancel_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=3, + message=timestamp_pb2.Timestamp, + ) + + class UpdateDatabaseDdlRequest(proto.Message): r"""Enqueues the given DDL statements to be applied, in order but not necessarily all at once, to the database schema at some point (or @@ -445,9 +546,24 @@ class UpdateDatabaseDdlRequest(proto.Message): [UpdateDatabaseDdl][google.spanner.admin.database.v1.DatabaseAdmin.UpdateDatabaseDdl] returns ``ALREADY_EXISTS``. proto_descriptors (bytes): - Proto descriptors used by CREATE/ALTER PROTO BUNDLE - statements. Contains a protobuf-serialized + Optional. Proto descriptors used by CREATE/ALTER PROTO + BUNDLE statements. Contains a protobuf-serialized `google.protobuf.FileDescriptorSet `__. + To generate it, + `install `__ and + run ``protoc`` with --include_imports and + --descriptor_set_out. For example, to generate for + moon/shot/app.proto, run + + :: + + $protoc --proto_path=/app_path --proto_path=/lib_path \ + --include_imports \ + --descriptor_set_out=descriptors.data \ + moon/shot/app.proto + + For more details, see protobuffer `self + description `__. """ database: str = proto.Field( @@ -468,6 +584,46 @@ class UpdateDatabaseDdlRequest(proto.Message): ) +class DdlStatementActionInfo(proto.Message): + r"""Action information extracted from a DDL statement. This proto is + used to display the brief info of the DDL statement for the + operation + [UpdateDatabaseDdl][google.spanner.admin.database.v1.DatabaseAdmin.UpdateDatabaseDdl]. + + Attributes: + action (str): + The action for the DDL statement, e.g. + CREATE, ALTER, DROP, GRANT, etc. This field is a + non-empty string. + entity_type (str): + The entity type for the DDL statement, e.g. 
TABLE, INDEX, + VIEW, etc. This field can be empty string for some DDL + statement, e.g. for statement "ANALYZE", ``entity_type`` = + "". + entity_names (MutableSequence[str]): + The entity name(s) being operated on the DDL statement. E.g. + + 1. For statement "CREATE TABLE t1(...)", ``entity_names`` = + ["t1"]. + 2. For statement "GRANT ROLE r1, r2 ...", ``entity_names`` = + ["r1", "r2"]. + 3. For statement "ANALYZE", ``entity_names`` = []. + """ + + action: str = proto.Field( + proto.STRING, + number=1, + ) + entity_type: str = proto.Field( + proto.STRING, + number=2, + ) + entity_names: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=3, + ) + + class UpdateDatabaseDdlMetadata(proto.Message): r"""Metadata type for the operation returned by [UpdateDatabaseDdl][google.spanner.admin.database.v1.DatabaseAdmin.UpdateDatabaseDdl]. @@ -485,20 +641,22 @@ class UpdateDatabaseDdlMetadata(proto.Message): commit timestamp for the statement ``statements[i]``. throttled (bool): Output only. When true, indicates that the - operation is throttled e.g due to resource + operation is throttled e.g. due to resource constraints. When resources become available the operation will resume and this field will be false again. progress (MutableSequence[google.cloud.spanner_admin_database_v1.types.OperationProgress]): The progress of the [UpdateDatabaseDdl][google.spanner.admin.database.v1.DatabaseAdmin.UpdateDatabaseDdl] - operations. Currently, only index creation statements will - have a continuously updating progress. For non-index - creation statements, ``progress[i]`` will have start time - and end time populated with commit timestamp of operation, - as well as a progress of 100% once the operation has - completed. ``progress[i]`` is the operation progress for - ``statements[i]``. + operations. All DDL statements will have continuously + updating progress, and ``progress[i]`` is the operation + progress for ``statements[i]``. 
Also, ``progress[i]`` will + have start time and end time populated with commit timestamp + of operation, as well as a progress of 100% once the + operation has completed. + actions (MutableSequence[google.cloud.spanner_admin_database_v1.types.DdlStatementActionInfo]): + The brief action info for the DDL statements. ``actions[i]`` + is the brief info for ``statements[i]``. """ database: str = proto.Field( @@ -523,6 +681,11 @@ class UpdateDatabaseDdlMetadata(proto.Message): number=5, message=common.OperationProgress, ) + actions: MutableSequence["DdlStatementActionInfo"] = proto.RepeatedField( + proto.MESSAGE, + number=6, + message="DdlStatementActionInfo", + ) class DropDatabaseRequest(proto.Message): @@ -570,6 +733,8 @@ class GetDatabaseDdlResponse(proto.Message): Proto descriptors stored in the database. Contains a protobuf-serialized `google.protobuf.FileDescriptorSet `__. + For more details, see protobuffer `self + description `__. """ statements: MutableSequence[str] = proto.RepeatedField( diff --git a/google/cloud/spanner_v1/types/type.py b/google/cloud/spanner_v1/types/type.py index 0d378a2efa..235b851748 100644 --- a/google/cloud/spanner_v1/types/type.py +++ b/google/cloud/spanner_v1/types/type.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # +from __future__ import annotations + from typing import MutableMapping, MutableSequence import proto # type: ignore @@ -48,6 +50,9 @@ class TypeCode(proto.Enum): FLOAT64 (3): Encoded as ``number``, or the strings ``"NaN"``, ``"Infinity"``, or ``"-Infinity"``. + FLOAT32 (15): + Encoded as ``number``, or the strings ``"NaN"``, + ``"Infinity"``, or ``"-Infinity"``. 
TIMESTAMP (4): Encoded as ``string`` in RFC 3339 timestamp format. The time zone must be present, and must be ``"Z"``. @@ -92,11 +97,17 @@ class TypeCode(proto.Enum): - Members of a JSON object are not guaranteed to have their order preserved. - JSON array elements will have their order preserved. + PROTO (13): + Encoded as a base64-encoded ``string``, as described in RFC + 4648, section 4. + ENUM (14): + Encoded as ``string``, in decimal format. """ TYPE_CODE_UNSPECIFIED = 0 BOOL = 1 INT64 = 2 FLOAT64 = 3 + FLOAT32 = 15 TIMESTAMP = 4 DATE = 5 STRING = 6 @@ -137,10 +148,17 @@ class TypeAnnotationCode(proto.Enum): PostgreSQL JSONB values. Currently this annotation is always needed for [JSON][google.spanner.v1.TypeCode.JSON] when a client interacts with PostgreSQL-enabled Spanner databases. + PG_OID (4): + PostgreSQL compatible OID type. This + annotation can be used by a client interacting + with PostgreSQL-enabled Spanner database to + specify that a value should be treated using the + semantics of the OID type. """ TYPE_ANNOTATION_CODE_UNSPECIFIED = 0 PG_NUMERIC = 2 PG_JSONB = 3 + PG_OID = 4 class Type(proto.Message): @@ -173,10 +191,12 @@ class Type(proto.Message): (it doesn't affect serialization) and clients can ignore it on the read path. proto_type_fqn (str): - If [code][] == [PROTO][TypeCode.PROTO] or [code][] == - [ENUM][TypeCode.ENUM], then ``proto_type_fqn`` is the fully - qualified name of the proto type representing the proto/enum - definition. + If [code][google.spanner.v1.Type.code] == + [PROTO][google.spanner.v1.TypeCode.PROTO] or + [code][google.spanner.v1.Type.code] == + [ENUM][google.spanner.v1.TypeCode.ENUM], then + ``proto_type_fqn`` is the fully qualified name of the proto + type representing the proto/enum definition. 
""" code: "TypeCode" = proto.Field( diff --git a/scripts/fixup_spanner_admin_database_v1_keywords.py b/scripts/fixup_spanner_admin_database_v1_keywords.py index 6049c63dd2..dcba0a2eb4 100644 --- a/scripts/fixup_spanner_admin_database_v1_keywords.py +++ b/scripts/fixup_spanner_admin_database_v1_keywords.py @@ -1,6 +1,6 @@ #! /usr/bin/env python3 # -*- coding: utf-8 -*- -# Copyright 2022 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -57,6 +57,7 @@ class spanner_admin_databaseCallTransformer(cst.CSTTransformer): 'set_iam_policy': ('resource', 'policy', 'update_mask', ), 'test_iam_permissions': ('resource', 'permissions', ), 'update_backup': ('backup', 'update_mask', ), + 'update_database': ('database', 'update_mask', ), 'update_database_ddl': ('database', 'statements', 'operation_id', 'proto_descriptors', ), } diff --git a/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py b/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py index 21cef8415f..a8292ee195 100644 --- a/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py +++ b/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2022 Google LLC +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -24,10 +24,17 @@ import grpc from grpc.experimental import aio +from collections.abc import Iterable +from google.protobuf import json_format +import json import math import pytest from proto.marshal.rules.dates import DurationRule, TimestampRule from proto.marshal.rules import wrappers +from requests import Response +from requests import Request, PreparedRequest +from requests.sessions import Session +from google.protobuf import json_format from google.api_core import client_options from google.api_core import exceptions as core_exceptions @@ -56,7 +63,7 @@ from google.iam.v1 import iam_policy_pb2 # type: ignore from google.iam.v1 import options_pb2 # type: ignore from google.iam.v1 import policy_pb2 # type: ignore -from google.longrunning import operations_pb2 +from google.longrunning import operations_pb2 # type: ignore from google.longrunning import operations_pb2 # type: ignore from google.oauth2 import service_account from google.protobuf import any_pb2 # type: ignore @@ -117,6 +124,7 @@ def test__get_default_mtls_endpoint(): [ (DatabaseAdminClient, "grpc"), (DatabaseAdminAsyncClient, "grpc_asyncio"), + (DatabaseAdminClient, "rest"), ], ) def test_database_admin_client_from_service_account_info(client_class, transport_name): @@ -130,7 +138,11 @@ def test_database_admin_client_from_service_account_info(client_class, transport assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == ("spanner.googleapis.com:443") + assert client.transport._host == ( + "spanner.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://spanner.googleapis.com" + ) @pytest.mark.parametrize( @@ -138,6 +150,7 @@ def test_database_admin_client_from_service_account_info(client_class, transport [ (transports.DatabaseAdminGrpcTransport, "grpc"), (transports.DatabaseAdminGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.DatabaseAdminRestTransport, "rest"), ], ) def 
test_database_admin_client_service_account_always_use_jwt( @@ -163,6 +176,7 @@ def test_database_admin_client_service_account_always_use_jwt( [ (DatabaseAdminClient, "grpc"), (DatabaseAdminAsyncClient, "grpc_asyncio"), + (DatabaseAdminClient, "rest"), ], ) def test_database_admin_client_from_service_account_file(client_class, transport_name): @@ -183,13 +197,18 @@ def test_database_admin_client_from_service_account_file(client_class, transport assert client.transport._credentials == creds assert isinstance(client, client_class) - assert client.transport._host == ("spanner.googleapis.com:443") + assert client.transport._host == ( + "spanner.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://spanner.googleapis.com" + ) def test_database_admin_client_get_transport_class(): transport = DatabaseAdminClient.get_transport_class() available_transports = [ transports.DatabaseAdminGrpcTransport, + transports.DatabaseAdminRestTransport, ] assert transport in available_transports @@ -206,6 +225,7 @@ def test_database_admin_client_get_transport_class(): transports.DatabaseAdminGrpcAsyncIOTransport, "grpc_asyncio", ), + (DatabaseAdminClient, transports.DatabaseAdminRestTransport, "rest"), ], ) @mock.patch.object( @@ -351,6 +371,8 @@ def test_database_admin_client_client_options( "grpc_asyncio", "false", ), + (DatabaseAdminClient, transports.DatabaseAdminRestTransport, "rest", "true"), + (DatabaseAdminClient, transports.DatabaseAdminRestTransport, "rest", "false"), ], ) @mock.patch.object( @@ -550,6 +572,7 @@ def test_database_admin_client_get_mtls_endpoint_and_cert_source(client_class): transports.DatabaseAdminGrpcAsyncIOTransport, "grpc_asyncio", ), + (DatabaseAdminClient, transports.DatabaseAdminRestTransport, "rest"), ], ) def test_database_admin_client_client_options_scopes( @@ -590,6 +613,7 @@ def test_database_admin_client_client_options_scopes( "grpc_asyncio", grpc_helpers_async, ), + (DatabaseAdminClient, 
transports.DatabaseAdminRestTransport, "rest", None), ], ) def test_database_admin_client_client_options_credentials_file( @@ -819,9 +843,9 @@ def test_list_databases_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -851,9 +875,9 @@ async def test_list_databases_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_list_databases_flattened(): @@ -1118,9 +1142,11 @@ async def test_list_databases_async_pages(): RuntimeError, ) pages = [] - async for page_ in ( + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch await client.list_databases(request={}) - ).pages: # pragma: no branch + ).pages: pages.append(page_) for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token @@ -1234,9 +1260,9 @@ def test_create_database_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1266,9 +1292,9 @@ async def test_create_database_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_create_database_flattened(): @@ -1389,6 +1415,8 @@ def test_get_database(request_type, transport: str = "grpc"): version_retention_period="version_retention_period_value", default_leader="default_leader_value", database_dialect=common.DatabaseDialect.GOOGLE_STANDARD_SQL, + enable_drop_protection=True, + reconciling=True, ) response = client.get_database(request) @@ -1404,6 +1432,8 @@ def test_get_database(request_type, transport: str = "grpc"): assert response.version_retention_period == "version_retention_period_value" assert response.default_leader == "default_leader_value" assert response.database_dialect == common.DatabaseDialect.GOOGLE_STANDARD_SQL + assert response.enable_drop_protection is True + assert response.reconciling is True def test_get_database_empty_call(): @@ -1446,6 +1476,8 @@ async def test_get_database_async( version_retention_period="version_retention_period_value", default_leader="default_leader_value", database_dialect=common.DatabaseDialect.GOOGLE_STANDARD_SQL, + enable_drop_protection=True, + reconciling=True, ) ) response = await client.get_database(request) @@ -1462,6 +1494,8 @@ async def test_get_database_async( assert response.version_retention_period == "version_retention_period_value" assert response.default_leader == "default_leader_value" assert response.database_dialect == common.DatabaseDialect.GOOGLE_STANDARD_SQL + assert response.enable_drop_protection is True + assert response.reconciling is True @pytest.mark.asyncio @@ -1493,9 +1527,9 @@ def test_get_database_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1525,9 +1559,9 @@ async def test_get_database_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] def test_get_database_flattened(): @@ -1612,6 +1646,243 @@ async def test_get_database_flattened_error_async(): ) +@pytest.mark.parametrize( + "request_type", + [ + spanner_database_admin.UpdateDatabaseRequest, + dict, + ], +) +def test_update_database(request_type, transport: str = "grpc"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_database), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.update_database(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == spanner_database_admin.UpdateDatabaseRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_update_database_empty_call(): + # This test is a coverage failsafe to make sure that totally empty calls, + # i.e. request == None and no flattened fields passed, work. 
+ client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_database), "__call__") as call: + client.update_database() + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == spanner_database_admin.UpdateDatabaseRequest() + + +@pytest.mark.asyncio +async def test_update_database_async( + transport: str = "grpc_asyncio", + request_type=spanner_database_admin.UpdateDatabaseRequest, +): + client = DatabaseAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_database), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.update_database(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == spanner_database_admin.UpdateDatabaseRequest() + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_update_database_async_from_dict(): + await test_update_database_async(request_type=dict) + + +def test_update_database_field_headers(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. 
+ request = spanner_database_admin.UpdateDatabaseRequest() + + request.database.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_database), "__call__") as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.update_database(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "database.name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_database_field_headers_async(): + client = DatabaseAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = spanner_database_admin.UpdateDatabaseRequest() + + request.database.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_database), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/op") + ) + await client.update_database(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "database.name=name_value", + ) in kw["metadata"] + + +def test_update_database_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. 
+ with mock.patch.object(type(client.transport.update_database), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_database( + database=spanner_database_admin.Database(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].database + mock_val = spanner_database_admin.Database(name="name_value") + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +def test_update_database_flattened_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_database( + spanner_database_admin.UpdateDatabaseRequest(), + database=spanner_database_admin.Database(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_database_flattened_async(): + client = DatabaseAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_database), "__call__") as call: + # Designate an appropriate return value for the call. 
+ call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_database( + database=spanner_database_admin.Database(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].database + mock_val = spanner_database_admin.Database(name="name_value") + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_update_database_flattened_error_async(): + client = DatabaseAdminAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_database( + spanner_database_admin.UpdateDatabaseRequest(), + database=spanner_database_admin.Database(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + @pytest.mark.parametrize( "request_type", [ @@ -1728,9 +1999,9 @@ def test_update_database_ddl_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "database=database_value", - ) in kw["metadata"] + "x-goog-request-params", + "database=database_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1762,9 +2033,9 @@ async def test_update_database_ddl_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "database=database_value", - ) in kw["metadata"] + "x-goog-request-params", + "database=database_value", + ) in kw["metadata"] def test_update_database_ddl_flattened(): @@ -1969,9 +2240,9 @@ def test_drop_database_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "database=database_value", - ) in kw["metadata"] + "x-goog-request-params", + "database=database_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1999,9 +2270,9 @@ async def test_drop_database_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "database=database_value", - ) in kw["metadata"] + "x-goog-request-params", + "database=database_value", + ) in kw["metadata"] def test_drop_database_flattened(): @@ -2202,9 +2473,9 @@ def test_get_database_ddl_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "database=database_value", - ) in kw["metadata"] + "x-goog-request-params", + "database=database_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -2234,9 +2505,9 @@ async def test_get_database_ddl_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "database=database_value", - ) in kw["metadata"] + "x-goog-request-params", + "database=database_value", + ) in kw["metadata"] def test_get_database_ddl_flattened(): @@ -2438,9 +2709,9 @@ def test_set_iam_policy_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "resource=resource_value", - ) in kw["metadata"] + "x-goog-request-params", + "resource=resource_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -2468,9 +2739,9 @@ async def test_set_iam_policy_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "resource=resource_value", - ) in kw["metadata"] + "x-goog-request-params", + "resource=resource_value", + ) in kw["metadata"] def test_set_iam_policy_from_dict_foreign(): @@ -2688,9 +2959,9 @@ def test_get_iam_policy_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "resource=resource_value", - ) in kw["metadata"] + "x-goog-request-params", + "resource=resource_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -2718,9 +2989,9 @@ async def test_get_iam_policy_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "resource=resource_value", - ) in kw["metadata"] + "x-goog-request-params", + "resource=resource_value", + ) in kw["metadata"] def test_get_iam_policy_from_dict_foreign(): @@ -2942,9 +3213,9 @@ def test_test_iam_permissions_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "resource=resource_value", - ) in kw["metadata"] + "x-goog-request-params", + "resource=resource_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -2976,9 +3247,9 @@ async def test_test_iam_permissions_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "resource=resource_value", - ) in kw["metadata"] + "x-goog-request-params", + "resource=resource_value", + ) in kw["metadata"] def test_test_iam_permissions_from_dict_foreign(): @@ -3203,9 +3474,9 @@ def test_create_backup_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -3235,9 +3506,9 @@ async def test_create_backup_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_create_backup_flattened(): @@ -3449,9 +3720,9 @@ def test_copy_backup_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -3481,9 +3752,9 @@ async def test_copy_backup_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_copy_backup_flattened(): @@ -3735,9 +4006,9 @@ def test_get_backup_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -3765,9 +4036,9 @@ async def test_get_backup_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] def test_get_backup_flattened(): @@ -3987,9 +4258,9 @@ def test_update_backup_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "backup.name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "backup.name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -4017,9 +4288,9 @@ async def test_update_backup_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "backup.name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "backup.name=name_value", + ) in kw["metadata"] def test_update_backup_flattened(): @@ -4217,9 +4488,9 @@ def test_delete_backup_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -4247,9 +4518,9 @@ async def test_delete_backup_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] def test_delete_backup_flattened(): @@ -4445,9 +4716,9 @@ def test_list_backups_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -4477,9 +4748,9 @@ async def test_list_backups_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_list_backups_flattened(): @@ -4744,9 +5015,11 @@ async def test_list_backups_async_pages(): RuntimeError, ) pages = [] - async for page_ in ( + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch await client.list_backups(request={}) - ).pages: # pragma: no branch + ).pages: pages.append(page_) for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token @@ -4860,9 +5133,9 @@ def test_restore_database_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -4892,9 +5165,9 @@ async def test_restore_database_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_restore_database_flattened(): @@ -5117,9 +5390,9 @@ def test_list_database_operations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -5151,9 +5424,9 @@ async def test_list_database_operations_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_list_database_operations_flattened(): @@ -5430,9 +5703,11 @@ async def test_list_database_operations_async_pages(): RuntimeError, ) pages = [] - async for page_ in ( + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch await client.list_database_operations(request={}) - ).pages: # pragma: no branch + ).pages: pages.append(page_) for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token @@ -5559,9 +5834,9 @@ def test_list_backup_operations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -5593,9 +5868,9 @@ async def test_list_backup_operations_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_list_backup_operations_flattened(): @@ -5872,9 +6147,11 @@ async def test_list_backup_operations_async_pages(): RuntimeError, ) pages = [] - async for page_ in ( + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch await client.list_backup_operations(request={}) - ).pages: # pragma: no branch + ).pages: pages.append(page_) for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token @@ -6002,9 +6279,9 @@ def test_list_database_roles_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -6036,9 +6313,9 @@ async def test_list_database_roles_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_list_database_roles_flattened(): @@ -6317,76 +6594,6273 @@ async def test_list_database_roles_async_pages(): RuntimeError, ) pages = [] - async for page_ in ( + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch await client.list_database_roles(request={}) - ).pages: # pragma: no branch + ).pages: pages.append(page_) for page_, token in zip(pages, ["abc", "def", "ghi", ""]): assert page_.raw_page.next_page_token == token -def test_credentials_transport_error(): - # It is an error to provide credentials and a transport instance. - transport = transports.DatabaseAdminGrpcTransport( +@pytest.mark.parametrize( + "request_type", + [ + spanner_database_admin.ListDatabasesRequest, + dict, + ], +) +def test_list_databases_rest(request_type): + client = DatabaseAdminClient( credentials=ga_credentials.AnonymousCredentials(), + transport="rest", ) - with pytest.raises(ValueError): - client = DatabaseAdminClient( - credentials=ga_credentials.AnonymousCredentials(), - transport=transport, - ) - # It is an error to provide a credentials file and a transport instance. - transport = transports.DatabaseAdminGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = DatabaseAdminClient( - client_options={"credentials_file": "credentials.json"}, - transport=transport, - ) + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request = request_type(**request_init) - # It is an error to provide an api_key and a transport instance. 
- transport = transports.DatabaseAdminGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - options = client_options.ClientOptions() - options.api_key = "api_key" - with pytest.raises(ValueError): - client = DatabaseAdminClient( - client_options=options, - transport=transport, + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.ListDatabasesResponse( + next_page_token="next_page_token_value", ) - # It is an error to provide an api_key and a credential. - options = mock.Mock() - options.api_key = "api_key" - with pytest.raises(ValueError): - client = DatabaseAdminClient( - client_options=options, credentials=ga_credentials.AnonymousCredentials() - ) + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = spanner_database_admin.ListDatabasesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) - # It is an error to provide scopes and a transport instance. - transport = transports.DatabaseAdminGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - with pytest.raises(ValueError): - client = DatabaseAdminClient( - client_options={"scopes": ["1", "2"]}, - transport=transport, - ) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_databases(request) + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDatabasesPager) + assert response.next_page_token == "next_page_token_value" -def test_transport_instance(): - # A client may be instantiated with a custom transport instance. 
- transport = transports.DatabaseAdminGrpcTransport( - credentials=ga_credentials.AnonymousCredentials(), - ) - client = DatabaseAdminClient(transport=transport) - assert client.transport is transport +def test_list_databases_rest_required_fields( + request_type=spanner_database_admin.ListDatabasesRequest, +): + transport_class = transports.DatabaseAdminRestTransport -def test_transport_get_channel(): + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_databases._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_databases._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "page_size", + "page_token", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.ListDatabasesResponse() + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = spanner_database_admin.ListDatabasesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.list_databases(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_list_databases_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.list_databases._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "pageSize", + "pageToken", + ) + ) + & set(("parent",)) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_databases_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, 
"transcode" + ) as transcode, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_list_databases" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_list_databases" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = spanner_database_admin.ListDatabasesRequest.pb( + spanner_database_admin.ListDatabasesRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = ( + spanner_database_admin.ListDatabasesResponse.to_json( + spanner_database_admin.ListDatabasesResponse() + ) + ) + + request = spanner_database_admin.ListDatabasesRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = spanner_database_admin.ListDatabasesResponse() + + client.list_databases( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_list_databases_rest_bad_request( + transport: str = "rest", request_type=spanner_database_admin.ListDatabasesRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_databases(request) + + +def test_list_databases_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.ListDatabasesResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/instances/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = spanner_database_admin.ListDatabasesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.list_databases(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/instances/*}/databases" % client.transport._host, + args[1], + ) + + +def test_list_databases_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_databases( + spanner_database_admin.ListDatabasesRequest(), + parent="parent_value", + ) + + +def test_list_databases_rest_pager(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + spanner_database_admin.ListDatabasesResponse( + databases=[ + spanner_database_admin.Database(), + spanner_database_admin.Database(), + spanner_database_admin.Database(), + ], + next_page_token="abc", + ), + spanner_database_admin.ListDatabasesResponse( + databases=[], + next_page_token="def", + ), + spanner_database_admin.ListDatabasesResponse( + databases=[ + spanner_database_admin.Database(), + ], + next_page_token="ghi", + ), + spanner_database_admin.ListDatabasesResponse( + databases=[ + spanner_database_admin.Database(), + spanner_database_admin.Database(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + spanner_database_admin.ListDatabasesResponse.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, 
response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "projects/sample1/instances/sample2"} + + pager = client.list_databases(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, spanner_database_admin.Database) for i in results) + + pages = list(client.list_databases(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + spanner_database_admin.CreateDatabaseRequest, + dict, + ], +) +def test_create_database_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.create_database(request) + + # Establish that the response is the type that we expect. 
+ assert response.operation.name == "operations/spam" + + +def test_create_database_rest_required_fields( + request_type=spanner_database_admin.CreateDatabaseRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["parent"] = "" + request_init["create_statement"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_database._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + jsonified_request["createStatement"] = "create_statement_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_database._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + assert "createStatement" in jsonified_request + assert jsonified_request["createStatement"] == "create_statement_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.create_database(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_database_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.create_database._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "createStatement", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_database_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, 
"_set_result_from_operation" + ), mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_create_database" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_create_database" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = spanner_database_admin.CreateDatabaseRequest.pb( + spanner_database_admin.CreateDatabaseRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = spanner_database_admin.CreateDatabaseRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.create_database( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_create_database_rest_bad_request( + transport: str = "rest", request_type=spanner_database_admin.CreateDatabaseRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.create_database(request) + + +def test_create_database_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/instances/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + create_statement="create_statement_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.create_database(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/instances/*}/databases" % client.transport._host, + args[1], + ) + + +def test_create_database_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.create_database( + spanner_database_admin.CreateDatabaseRequest(), + parent="parent_value", + create_statement="create_statement_value", + ) + + +def test_create_database_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + spanner_database_admin.GetDatabaseRequest, + dict, + ], +) +def test_get_database_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"name": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.Database( + name="name_value", + state=spanner_database_admin.Database.State.CREATING, + version_retention_period="version_retention_period_value", + default_leader="default_leader_value", + database_dialect=common.DatabaseDialect.GOOGLE_STANDARD_SQL, + enable_drop_protection=True, + reconciling=True, + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = spanner_database_admin.Database.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_database(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, spanner_database_admin.Database) + assert response.name == "name_value" + assert response.state == spanner_database_admin.Database.State.CREATING + assert response.version_retention_period == "version_retention_period_value" + assert response.default_leader == "default_leader_value" + assert response.database_dialect == common.DatabaseDialect.GOOGLE_STANDARD_SQL + assert response.enable_drop_protection is True + assert response.reconciling is True + + +def test_get_database_rest_required_fields( + request_type=spanner_database_admin.GetDatabaseRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_database._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_database._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.Database() + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = spanner_database_admin.Database.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_database(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_database_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_database._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_database_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + 
transports.DatabaseAdminRestInterceptor, "post_get_database" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_get_database" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = spanner_database_admin.GetDatabaseRequest.pb( + spanner_database_admin.GetDatabaseRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = spanner_database_admin.Database.to_json( + spanner_database_admin.Database() + ) + + request = spanner_database_admin.GetDatabaseRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = spanner_database_admin.Database() + + client.get_database( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_database_rest_bad_request( + transport: str = "rest", request_type=spanner_database_admin.GetDatabaseRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"name": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_database(request) + + +def test_get_database_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.Database() + + # get arguments that satisfy an http rule for this method + sample_request = { + "name": "projects/sample1/instances/sample2/databases/sample3" + } + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = spanner_database_admin.Database.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.get_database(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{name=projects/*/instances/*/databases/*}" % client.transport._host, + args[1], + ) + + +def test_get_database_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_database( + spanner_database_admin.GetDatabaseRequest(), + name="name_value", + ) + + +def test_get_database_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + spanner_database_admin.UpdateDatabaseRequest, + dict, + ], +) +def test_update_database_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "database": {"name": "projects/sample1/instances/sample2/databases/sample3"} + } + request_init["database"] = { + "name": "projects/sample1/instances/sample2/databases/sample3", + "state": 1, + "create_time": {"seconds": 751, "nanos": 543}, + "restore_info": { + "source_type": 1, + "backup_info": { + "backup": "backup_value", + "version_time": {}, + "create_time": {}, + "source_database": "source_database_value", + }, + }, + "encryption_config": {"kms_key_name": "kms_key_name_value"}, + "encryption_info": [ + { + "encryption_type": 1, + "encryption_status": { + "code": 411, + "message": "message_value", + "details": [ + { + "type_url": "type.googleapis.com/google.protobuf.Duration", + "value": b"\x08\x0c\x10\xdb\x07", + } + ], + }, + "kms_key_version": "kms_key_version_value", + } + ], + "version_retention_period": "version_retention_period_value", + "earliest_version_time": {}, + 
"default_leader": "default_leader_value", + "database_dialect": 1, + "enable_drop_protection": True, + "reconciling": True, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = spanner_database_admin.UpdateDatabaseRequest.meta.fields["database"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["database"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in 
runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["database"][field])): + del request_init["database"][field][i][subfield] + else: + del request_init["database"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.update_database(request) + + # Establish that the response is the type that we expect. 
+ assert response.operation.name == "operations/spam" + + +def test_update_database_rest_required_fields( + request_type=spanner_database_admin.UpdateDatabaseRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_database._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_database._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("update_mask",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "patch", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.update_database(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_update_database_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.update_database._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(("updateMask",)) + & set( + ( + "database", + "updateMask", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_update_database_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_update_database" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_update_database" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = spanner_database_admin.UpdateDatabaseRequest.pb( + spanner_database_admin.UpdateDatabaseRequest() + ) + 
transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = spanner_database_admin.UpdateDatabaseRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.update_database( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_update_database_rest_bad_request( + transport: str = "rest", request_type=spanner_database_admin.UpdateDatabaseRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "database": {"name": "projects/sample1/instances/sample2/databases/sample3"} + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.update_database(request) + + +def test_update_database_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = { + "database": {"name": "projects/sample1/instances/sample2/databases/sample3"} + } + + # get truthy value for each flattened field + mock_args = dict( + database=spanner_database_admin.Database(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.update_database(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{database.name=projects/*/instances/*/databases/*}" + % client.transport._host, + args[1], + ) + + +def test_update_database_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.update_database( + spanner_database_admin.UpdateDatabaseRequest(), + database=spanner_database_admin.Database(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +def test_update_database_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + spanner_database_admin.UpdateDatabaseDdlRequest, + dict, + ], +) +def test_update_database_ddl_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"database": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.update_database_ddl(request) + + # Establish that the response is the type that we expect. 
+ assert response.operation.name == "operations/spam" + + +def test_update_database_ddl_rest_required_fields( + request_type=spanner_database_admin.UpdateDatabaseDdlRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["database"] = "" + request_init["statements"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_database_ddl._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["database"] = "database_value" + jsonified_request["statements"] = "statements_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_database_ddl._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "database" in jsonified_request + assert jsonified_request["database"] == "database_value" + assert "statements" in jsonified_request + assert jsonified_request["statements"] == "statements_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "patch", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.update_database_ddl(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_update_database_ddl_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.update_database_ddl._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "database", + "statements", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_update_database_ddl_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, 
"_set_result_from_operation" + ), mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_update_database_ddl" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_update_database_ddl" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = spanner_database_admin.UpdateDatabaseDdlRequest.pb( + spanner_database_admin.UpdateDatabaseDdlRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = spanner_database_admin.UpdateDatabaseDdlRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.update_database_ddl( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_update_database_ddl_rest_bad_request( + transport: str = "rest", + request_type=spanner_database_admin.UpdateDatabaseDdlRequest, +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"database": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.update_database_ddl(request) + + +def test_update_database_ddl_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = { + "database": "projects/sample1/instances/sample2/databases/sample3" + } + + # get truthy value for each flattened field + mock_args = dict( + database="database_value", + statements=["statements_value"], + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.update_database_ddl(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{database=projects/*/instances/*/databases/*}/ddl" + % client.transport._host, + args[1], + ) + + +def test_update_database_ddl_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.update_database_ddl( + spanner_database_admin.UpdateDatabaseDdlRequest(), + database="database_value", + statements=["statements_value"], + ) + + +def test_update_database_ddl_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + spanner_database_admin.DropDatabaseRequest, + dict, + ], +) +def test_drop_database_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"database": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.drop_database(request) + + # Establish that the response is the type that we expect. 
+ assert response is None + + +def test_drop_database_rest_required_fields( + request_type=spanner_database_admin.DropDatabaseRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["database"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).drop_database._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["database"] = "database_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).drop_database._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "database" in jsonified_request + assert jsonified_request["database"] == "database_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.drop_database(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_drop_database_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.drop_database._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("database",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_drop_database_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_drop_database" + ) as pre: + pre.assert_not_called() + pb_message = spanner_database_admin.DropDatabaseRequest.pb( + spanner_database_admin.DropDatabaseRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + + request = spanner_database_admin.DropDatabaseRequest() + metadata = [ + ("key", "val"), + ("cephalopod", 
"squid"), + ] + pre.return_value = request, metadata + + client.drop_database( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + + +def test_drop_database_rest_bad_request( + transport: str = "rest", request_type=spanner_database_admin.DropDatabaseRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"database": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.drop_database(request) + + +def test_drop_database_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = { + "database": "projects/sample1/instances/sample2/databases/sample3" + } + + # get truthy value for each flattened field + mock_args = dict( + database="database_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.drop_database(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{database=projects/*/instances/*/databases/*}" + % client.transport._host, + args[1], + ) + + +def test_drop_database_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.drop_database( + spanner_database_admin.DropDatabaseRequest(), + database="database_value", + ) + + +def test_drop_database_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + spanner_database_admin.GetDatabaseDdlRequest, + dict, + ], +) +def test_get_database_ddl_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"database": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.GetDatabaseDdlResponse( + statements=["statements_value"], + proto_descriptors=b"proto_descriptors_blob", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = spanner_database_admin.GetDatabaseDdlResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_database_ddl(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, spanner_database_admin.GetDatabaseDdlResponse) + assert response.statements == ["statements_value"] + assert response.proto_descriptors == b"proto_descriptors_blob" + + +def test_get_database_ddl_rest_required_fields( + request_type=spanner_database_admin.GetDatabaseDdlRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["database"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_database_ddl._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["database"] = "database_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_database_ddl._get_unset_required_fields(jsonified_request) + 
jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "database" in jsonified_request + assert jsonified_request["database"] == "database_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.GetDatabaseDdlResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = spanner_database_admin.GetDatabaseDdlResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_database_ddl(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_database_ddl_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_database_ddl._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) 
& set(("database",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_database_ddl_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_get_database_ddl" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_get_database_ddl" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = spanner_database_admin.GetDatabaseDdlRequest.pb( + spanner_database_admin.GetDatabaseDdlRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = ( + spanner_database_admin.GetDatabaseDdlResponse.to_json( + spanner_database_admin.GetDatabaseDdlResponse() + ) + ) + + request = spanner_database_admin.GetDatabaseDdlRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = spanner_database_admin.GetDatabaseDdlResponse() + + client.get_database_ddl( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_database_ddl_rest_bad_request( + transport: str = "rest", request_type=spanner_database_admin.GetDatabaseDdlRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy 
transcoding + request_init = {"database": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_database_ddl(request) + + +def test_get_database_ddl_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.GetDatabaseDdlResponse() + + # get arguments that satisfy an http rule for this method + sample_request = { + "database": "projects/sample1/instances/sample2/databases/sample3" + } + + # get truthy value for each flattened field + mock_args = dict( + database="database_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = spanner_database_admin.GetDatabaseDdlResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.get_database_ddl(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{database=projects/*/instances/*/databases/*}/ddl" + % client.transport._host, + args[1], + ) + + +def test_get_database_ddl_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_database_ddl( + spanner_database_admin.GetDatabaseDdlRequest(), + database="database_value", + ) + + +def test_get_database_ddl_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + iam_policy_pb2.SetIamPolicyRequest, + dict, + ], +) +def test_set_iam_policy_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"resource": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = policy_pb2.Policy( + version=774, + etag=b"etag_blob", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.set_iam_policy(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, policy_pb2.Policy) + assert response.version == 774 + assert response.etag == b"etag_blob" + + +def test_set_iam_policy_rest_required_fields( + request_type=iam_policy_pb2.SetIamPolicyRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["resource"] = "" + request = request_type(**request_init) + pb_request = request + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).set_iam_policy._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["resource"] = "resource_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).set_iam_policy._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "resource" in jsonified_request + assert jsonified_request["resource"] == "resource_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = policy_pb2.Policy() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. 
+ with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.set_iam_policy(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_set_iam_policy_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.set_iam_policy._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "resource", + "policy", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_set_iam_policy_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_set_iam_policy" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_set_iam_policy" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = iam_policy_pb2.SetIamPolicyRequest() + 
transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson(policy_pb2.Policy()) + + request = iam_policy_pb2.SetIamPolicyRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = policy_pb2.Policy() + + client.set_iam_policy( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_set_iam_policy_rest_bad_request( + transport: str = "rest", request_type=iam_policy_pb2.SetIamPolicyRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"resource": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.set_iam_policy(request) + + +def test_set_iam_policy_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = policy_pb2.Policy() + + # get arguments that satisfy an http rule for this method + sample_request = { + "resource": "projects/sample1/instances/sample2/databases/sample3" + } + + # get truthy value for each flattened field + mock_args = dict( + resource="resource_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.set_iam_policy(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{resource=projects/*/instances/*/databases/*}:setIamPolicy" + % client.transport._host, + args[1], + ) + + +def test_set_iam_policy_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.set_iam_policy( + iam_policy_pb2.SetIamPolicyRequest(), + resource="resource_value", + ) + + +def test_set_iam_policy_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + iam_policy_pb2.GetIamPolicyRequest, + dict, + ], +) +def test_get_iam_policy_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"resource": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = policy_pb2.Policy( + version=774, + etag=b"etag_blob", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_iam_policy(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, policy_pb2.Policy) + assert response.version == 774 + assert response.etag == b"etag_blob" + + +def test_get_iam_policy_rest_required_fields( + request_type=iam_policy_pb2.GetIamPolicyRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["resource"] = "" + request = request_type(**request_init) + pb_request = request + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_iam_policy._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["resource"] = "resource_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_iam_policy._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "resource" in jsonified_request + assert jsonified_request["resource"] == "resource_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = policy_pb2.Policy() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. 
+ with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_iam_policy(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_iam_policy_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_iam_policy._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("resource",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_iam_policy_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_get_iam_policy" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_get_iam_policy" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = iam_policy_pb2.GetIamPolicyRequest() + transcode.return_value = { + "method": "post", + 
"uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson(policy_pb2.Policy()) + + request = iam_policy_pb2.GetIamPolicyRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = policy_pb2.Policy() + + client.get_iam_policy( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_iam_policy_rest_bad_request( + transport: str = "rest", request_type=iam_policy_pb2.GetIamPolicyRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"resource": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_iam_policy(request) + + +def test_get_iam_policy_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = policy_pb2.Policy() + + # get arguments that satisfy an http rule for this method + sample_request = { + "resource": "projects/sample1/instances/sample2/databases/sample3" + } + + # get truthy value for each flattened field + mock_args = dict( + resource="resource_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.get_iam_policy(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{resource=projects/*/instances/*/databases/*}:getIamPolicy" + % client.transport._host, + args[1], + ) + + +def test_get_iam_policy_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.get_iam_policy( + iam_policy_pb2.GetIamPolicyRequest(), + resource="resource_value", + ) + + +def test_get_iam_policy_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + iam_policy_pb2.TestIamPermissionsRequest, + dict, + ], +) +def test_test_iam_permissions_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"resource": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = iam_policy_pb2.TestIamPermissionsResponse( + permissions=["permissions_value"], + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.test_iam_permissions(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, iam_policy_pb2.TestIamPermissionsResponse) + assert response.permissions == ["permissions_value"] + + +def test_test_iam_permissions_rest_required_fields( + request_type=iam_policy_pb2.TestIamPermissionsRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["resource"] = "" + request_init["permissions"] = "" + request = request_type(**request_init) + pb_request = request + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).test_iam_permissions._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["resource"] = "resource_value" + jsonified_request["permissions"] = "permissions_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).test_iam_permissions._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "resource" in jsonified_request + assert jsonified_request["resource"] == "resource_value" + assert "permissions" in jsonified_request + assert jsonified_request["permissions"] == "permissions_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = iam_policy_pb2.TestIamPermissionsResponse() + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.test_iam_permissions(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_test_iam_permissions_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.test_iam_permissions._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "resource", + "permissions", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_test_iam_permissions_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + 
transports.DatabaseAdminRestInterceptor, "post_test_iam_permissions" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_test_iam_permissions" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = iam_policy_pb2.TestIamPermissionsRequest() + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + iam_policy_pb2.TestIamPermissionsResponse() + ) + + request = iam_policy_pb2.TestIamPermissionsRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = iam_policy_pb2.TestIamPermissionsResponse() + + client.test_iam_permissions( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_test_iam_permissions_rest_bad_request( + transport: str = "rest", request_type=iam_policy_pb2.TestIamPermissionsRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"resource": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.test_iam_permissions(request) + + +def test_test_iam_permissions_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = iam_policy_pb2.TestIamPermissionsResponse() + + # get arguments that satisfy an http rule for this method + sample_request = { + "resource": "projects/sample1/instances/sample2/databases/sample3" + } + + # get truthy value for each flattened field + mock_args = dict( + resource="resource_value", + permissions=["permissions_value"], + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.test_iam_permissions(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{resource=projects/*/instances/*/databases/*}:testIamPermissions" + % client.transport._host, + args[1], + ) + + +def test_test_iam_permissions_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.test_iam_permissions( + iam_policy_pb2.TestIamPermissionsRequest(), + resource="resource_value", + permissions=["permissions_value"], + ) + + +def test_test_iam_permissions_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + gsad_backup.CreateBackupRequest, + dict, + ], +) +def test_create_backup_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request_init["backup"] = { + "database": "database_value", + "version_time": {"seconds": 751, "nanos": 543}, + "expire_time": {}, + "name": "name_value", + "create_time": {}, + "size_bytes": 1089, + "state": 1, + "referencing_databases": [ + "referencing_databases_value1", + "referencing_databases_value2", + ], + "encryption_info": { + "encryption_type": 1, + "encryption_status": { + "code": 411, + "message": "message_value", + "details": [ + { + "type_url": "type.googleapis.com/google.protobuf.Duration", + "value": b"\x08\x0c\x10\xdb\x07", + } + ], + }, + "kms_key_version": "kms_key_version_value", + }, + "database_dialect": 1, + "referencing_backups": [ + "referencing_backups_value1", + "referencing_backups_value2", + ], + "max_expire_time": {}, + } + # The 
version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = gsad_backup.CreateBackupRequest.meta.fields["backup"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["backup"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove 
fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["backup"][field])): + del request_init["backup"][field][i][subfield] + else: + del request_init["backup"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.create_backup(request) + + # Establish that the response is the type that we expect. 
+ assert response.operation.name == "operations/spam" + + +def test_create_backup_rest_required_fields( + request_type=gsad_backup.CreateBackupRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["parent"] = "" + request_init["backup_id"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + assert "backupId" not in jsonified_request + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_backup._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + assert "backupId" in jsonified_request + assert jsonified_request["backupId"] == request_init["backup_id"] + + jsonified_request["parent"] = "parent_value" + jsonified_request["backupId"] = "backup_id_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_backup._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "backup_id", + "encryption_config", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + assert "backupId" in jsonified_request + assert jsonified_request["backupId"] == "backup_id_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. 
+ return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.create_backup(request) + + expected_params = [ + ( + "backupId", + "", + ), + ("$alt", "json;enum-encoding=int"), + ] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_backup_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.create_backup._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "backupId", + "encryptionConfig", + ) + ) + & set( + ( + "parent", + "backupId", + "backup", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_backup_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = 
DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_create_backup" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_create_backup" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = gsad_backup.CreateBackupRequest.pb( + gsad_backup.CreateBackupRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = gsad_backup.CreateBackupRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.create_backup( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_create_backup_rest_bad_request( + transport: str = "rest", request_type=gsad_backup.CreateBackupRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.create_backup(request) + + +def test_create_backup_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/instances/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + backup=gsad_backup.Backup(database="database_value"), + backup_id="backup_id_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.create_backup(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/instances/*}/backups" % client.transport._host, + args[1], + ) + + +def test_create_backup_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.create_backup( + gsad_backup.CreateBackupRequest(), + parent="parent_value", + backup=gsad_backup.Backup(database="database_value"), + backup_id="backup_id_value", + ) + + +def test_create_backup_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + backup.CopyBackupRequest, + dict, + ], +) +def test_copy_backup_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.copy_backup(request) + + # Establish that the response is the type that we expect. 
+ assert response.operation.name == "operations/spam" + + +def test_copy_backup_rest_required_fields(request_type=backup.CopyBackupRequest): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["parent"] = "" + request_init["backup_id"] = "" + request_init["source_backup"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).copy_backup._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + jsonified_request["backupId"] = "backup_id_value" + jsonified_request["sourceBackup"] = "source_backup_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).copy_backup._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + assert "backupId" in jsonified_request + assert jsonified_request["backupId"] == "backup_id_value" + assert "sourceBackup" in jsonified_request + assert jsonified_request["sourceBackup"] == "source_backup_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.copy_backup(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_copy_backup_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.copy_backup._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "backupId", + "sourceBackup", + "expireTime", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_copy_backup_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, 
"_set_result_from_operation" + ), mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_copy_backup" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_copy_backup" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = backup.CopyBackupRequest.pb(backup.CopyBackupRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = backup.CopyBackupRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.copy_backup( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_copy_backup_rest_bad_request( + transport: str = "rest", request_type=backup.CopyBackupRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.copy_backup(request) + + +def test_copy_backup_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/instances/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + backup_id="backup_id_value", + source_backup="source_backup_value", + expire_time=timestamp_pb2.Timestamp(seconds=751), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.copy_backup(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/instances/*}/backups:copy" + % client.transport._host, + args[1], + ) + + +def test_copy_backup_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.copy_backup( + backup.CopyBackupRequest(), + parent="parent_value", + backup_id="backup_id_value", + source_backup="source_backup_value", + expire_time=timestamp_pb2.Timestamp(seconds=751), + ) + + +def test_copy_backup_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + backup.GetBackupRequest, + dict, + ], +) +def test_get_backup_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"name": "projects/sample1/instances/sample2/backups/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = backup.Backup( + database="database_value", + name="name_value", + size_bytes=1089, + state=backup.Backup.State.CREATING, + referencing_databases=["referencing_databases_value"], + database_dialect=common.DatabaseDialect.GOOGLE_STANDARD_SQL, + referencing_backups=["referencing_backups_value"], + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = backup.Backup.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.get_backup(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, backup.Backup) + assert response.database == "database_value" + assert response.name == "name_value" + assert response.size_bytes == 1089 + assert response.state == backup.Backup.State.CREATING + assert response.referencing_databases == ["referencing_databases_value"] + assert response.database_dialect == common.DatabaseDialect.GOOGLE_STANDARD_SQL + assert response.referencing_backups == ["referencing_backups_value"] + + +def test_get_backup_rest_required_fields(request_type=backup.GetBackupRequest): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_backup._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + 
jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_backup._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = backup.Backup() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = backup.Backup.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_backup(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_backup_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_backup._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_backup_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_get_backup" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_get_backup" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = backup.GetBackupRequest.pb(backup.GetBackupRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + 
req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = backup.Backup.to_json(backup.Backup()) + + request = backup.GetBackupRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = backup.Backup() + + client.get_backup( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_backup_rest_bad_request( + transport: str = "rest", request_type=backup.GetBackupRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"name": "projects/sample1/instances/sample2/backups/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_backup(request) + + +def test_get_backup_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = backup.Backup() + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "projects/sample1/instances/sample2/backups/sample3"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = backup.Backup.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.get_backup(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{name=projects/*/instances/*/backups/*}" % client.transport._host, + args[1], + ) + + +def test_get_backup_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.get_backup( + backup.GetBackupRequest(), + name="name_value", + ) + + +def test_get_backup_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + gsad_backup.UpdateBackupRequest, + dict, + ], +) +def test_update_backup_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = { + "backup": {"name": "projects/sample1/instances/sample2/backups/sample3"} + } + request_init["backup"] = { + "database": "database_value", + "version_time": {"seconds": 751, "nanos": 543}, + "expire_time": {}, + "name": "projects/sample1/instances/sample2/backups/sample3", + "create_time": {}, + "size_bytes": 1089, + "state": 1, + "referencing_databases": [ + "referencing_databases_value1", + "referencing_databases_value2", + ], + "encryption_info": { + "encryption_type": 1, + "encryption_status": { + "code": 411, + "message": "message_value", + "details": [ + { + "type_url": "type.googleapis.com/google.protobuf.Duration", + "value": b"\x08\x0c\x10\xdb\x07", + } + ], + }, + "kms_key_version": "kms_key_version_value", + }, + "database_dialect": 1, + "referencing_backups": [ + "referencing_backups_value1", + "referencing_backups_value2", + ], + "max_expire_time": {}, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = gsad_backup.UpdateBackupRequest.meta.fields["backup"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. 
+ # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["backup"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["backup"][field])): + del 
request_init["backup"][field][i][subfield] + else: + del request_init["backup"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = gsad_backup.Backup( + database="database_value", + name="name_value", + size_bytes=1089, + state=gsad_backup.Backup.State.CREATING, + referencing_databases=["referencing_databases_value"], + database_dialect=common.DatabaseDialect.GOOGLE_STANDARD_SQL, + referencing_backups=["referencing_backups_value"], + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = gsad_backup.Backup.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.update_backup(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, gsad_backup.Backup) + assert response.database == "database_value" + assert response.name == "name_value" + assert response.size_bytes == 1089 + assert response.state == gsad_backup.Backup.State.CREATING + assert response.referencing_databases == ["referencing_databases_value"] + assert response.database_dialect == common.DatabaseDialect.GOOGLE_STANDARD_SQL + assert response.referencing_backups == ["referencing_backups_value"] + + +def test_update_backup_rest_required_fields( + request_type=gsad_backup.UpdateBackupRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_backup._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_backup._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("update_mask",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = gsad_backup.Backup() + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "patch", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = gsad_backup.Backup.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.update_backup(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_update_backup_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.update_backup._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(("updateMask",)) + & set( + ( + "backup", + "updateMask", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_update_backup_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + 
path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_update_backup" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_update_backup" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = gsad_backup.UpdateBackupRequest.pb( + gsad_backup.UpdateBackupRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = gsad_backup.Backup.to_json(gsad_backup.Backup()) + + request = gsad_backup.UpdateBackupRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = gsad_backup.Backup() + + client.update_backup( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_update_backup_rest_bad_request( + transport: str = "rest", request_type=gsad_backup.UpdateBackupRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = { + "backup": {"name": "projects/sample1/instances/sample2/backups/sample3"} + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.update_backup(request) + + +def test_update_backup_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = gsad_backup.Backup() + + # get arguments that satisfy an http rule for this method + sample_request = { + "backup": {"name": "projects/sample1/instances/sample2/backups/sample3"} + } + + # get truthy value for each flattened field + mock_args = dict( + backup=gsad_backup.Backup(database="database_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = gsad_backup.Backup.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.update_backup(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{backup.name=projects/*/instances/*/backups/*}" + % client.transport._host, + args[1], + ) + + +def test_update_backup_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_backup( + gsad_backup.UpdateBackupRequest(), + backup=gsad_backup.Backup(database="database_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +def test_update_backup_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + backup.DeleteBackupRequest, + dict, + ], +) +def test_delete_backup_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"name": "projects/sample1/instances/sample2/backups/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.delete_backup(request) + + # Establish that the response is the type that we expect. 
+ assert response is None + + +def test_delete_backup_rest_required_fields(request_type=backup.DeleteBackupRequest): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_backup._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_backup._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.delete_backup(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_delete_backup_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.delete_backup._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_delete_backup_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_delete_backup" + ) as pre: + pre.assert_not_called() + pb_message = backup.DeleteBackupRequest.pb(backup.DeleteBackupRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + + request = backup.DeleteBackupRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + + 
client.delete_backup( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + + +def test_delete_backup_rest_bad_request( + transport: str = "rest", request_type=backup.DeleteBackupRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"name": "projects/sample1/instances/sample2/backups/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.delete_backup(request) + + +def test_delete_backup_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "projects/sample1/instances/sample2/backups/sample3"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.delete_backup(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{name=projects/*/instances/*/backups/*}" % client.transport._host, + args[1], + ) + + +def test_delete_backup_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_backup( + backup.DeleteBackupRequest(), + name="name_value", + ) + + +def test_delete_backup_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + backup.ListBackupsRequest, + dict, + ], +) +def test_list_backups_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = backup.ListBackupsResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = backup.ListBackupsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_backups(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, pagers.ListBackupsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_backups_rest_required_fields(request_type=backup.ListBackupsRequest): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_backups._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_backups._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "filter", + "page_size", + "page_token", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = backup.ListBackupsResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. 
+ with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = backup.ListBackupsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.list_backups(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_list_backups_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.list_backups._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "filter", + "pageSize", + "pageToken", + ) + ) + & set(("parent",)) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_backups_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_list_backups" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_list_backups" + ) as pre: + pre.assert_not_called() + 
post.assert_not_called() + pb_message = backup.ListBackupsRequest.pb(backup.ListBackupsRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = backup.ListBackupsResponse.to_json( + backup.ListBackupsResponse() + ) + + request = backup.ListBackupsRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = backup.ListBackupsResponse() + + client.list_backups( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_list_backups_rest_bad_request( + transport: str = "rest", request_type=backup.ListBackupsRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_backups(request) + + +def test_list_backups_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = backup.ListBackupsResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/instances/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = backup.ListBackupsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.list_backups(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/instances/*}/backups" % client.transport._host, + args[1], + ) + + +def test_list_backups_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_backups( + backup.ListBackupsRequest(), + parent="parent_value", + ) + + +def test_list_backups_rest_pager(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. 
+ # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + backup.ListBackupsResponse( + backups=[ + backup.Backup(), + backup.Backup(), + backup.Backup(), + ], + next_page_token="abc", + ), + backup.ListBackupsResponse( + backups=[], + next_page_token="def", + ), + backup.ListBackupsResponse( + backups=[ + backup.Backup(), + ], + next_page_token="ghi", + ), + backup.ListBackupsResponse( + backups=[ + backup.Backup(), + backup.Backup(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(backup.ListBackupsResponse.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "projects/sample1/instances/sample2"} + + pager = client.list_backups(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, backup.Backup) for i in results) + + pages = list(client.list_backups(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + spanner_database_admin.RestoreDatabaseRequest, + dict, + ], +) +def test_restore_database_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.restore_database(request) + + # Establish that the response is the type that we expect. + assert response.operation.name == "operations/spam" + + +def test_restore_database_rest_required_fields( + request_type=spanner_database_admin.RestoreDatabaseRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["parent"] = "" + request_init["database_id"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).restore_database._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + jsonified_request["databaseId"] = "database_id_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).restore_database._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + assert "databaseId" in jsonified_request + assert 
jsonified_request["databaseId"] == "database_id_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.restore_database(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_restore_database_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.restore_database._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "databaseId", + ) + ) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_restore_database_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + 
credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_restore_database" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_restore_database" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = spanner_database_admin.RestoreDatabaseRequest.pb( + spanner_database_admin.RestoreDatabaseRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = json_format.MessageToJson( + operations_pb2.Operation() + ) + + request = spanner_database_admin.RestoreDatabaseRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.restore_database( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_restore_database_rest_bad_request( + transport: str = "rest", request_type=spanner_database_admin.RestoreDatabaseRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.restore_database(request) + + +def test_restore_database_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/instances/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + database_id="database_id_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.restore_database(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/instances/*}/databases:restore" + % client.transport._host, + args[1], + ) + + +def test_restore_database_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. 
+ with pytest.raises(ValueError): + client.restore_database( + spanner_database_admin.RestoreDatabaseRequest(), + parent="parent_value", + database_id="database_id_value", + backup="backup_value", + ) + + +def test_restore_database_rest_error(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + +@pytest.mark.parametrize( + "request_type", + [ + spanner_database_admin.ListDatabaseOperationsRequest, + dict, + ], +) +def test_list_database_operations_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.ListDatabaseOperationsResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = spanner_database_admin.ListDatabaseOperationsResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_database_operations(request) + + # Establish that the response is the type that we expect. 
+ assert isinstance(response, pagers.ListDatabaseOperationsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_database_operations_rest_required_fields( + request_type=spanner_database_admin.ListDatabaseOperationsRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_database_operations._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_database_operations._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "filter", + "page_size", + "page_token", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.ListDatabaseOperationsResponse() + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = spanner_database_admin.ListDatabaseOperationsResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.list_database_operations(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_list_database_operations_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.list_database_operations._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "filter", + "pageSize", + "pageToken", + ) + ) + & set(("parent",)) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_database_operations_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + 
type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_list_database_operations" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_list_database_operations" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = spanner_database_admin.ListDatabaseOperationsRequest.pb( + spanner_database_admin.ListDatabaseOperationsRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = ( + spanner_database_admin.ListDatabaseOperationsResponse.to_json( + spanner_database_admin.ListDatabaseOperationsResponse() + ) + ) + + request = spanner_database_admin.ListDatabaseOperationsRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = spanner_database_admin.ListDatabaseOperationsResponse() + + client.list_database_operations( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_list_database_operations_rest_bad_request( + transport: str = "rest", + request_type=spanner_database_admin.ListDatabaseOperationsRequest, +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. 
+ with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_database_operations(request) + + +def test_list_database_operations_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.ListDatabaseOperationsResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/instances/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = spanner_database_admin.ListDatabaseOperationsResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.list_database_operations(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. 
+ assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/instances/*}/databaseOperations" + % client.transport._host, + args[1], + ) + + +def test_list_database_operations_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_database_operations( + spanner_database_admin.ListDatabaseOperationsRequest(), + parent="parent_value", + ) + + +def test_list_database_operations_rest_pager(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. 
+ # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + spanner_database_admin.ListDatabaseOperationsResponse( + operations=[ + operations_pb2.Operation(), + operations_pb2.Operation(), + operations_pb2.Operation(), + ], + next_page_token="abc", + ), + spanner_database_admin.ListDatabaseOperationsResponse( + operations=[], + next_page_token="def", + ), + spanner_database_admin.ListDatabaseOperationsResponse( + operations=[ + operations_pb2.Operation(), + ], + next_page_token="ghi", + ), + spanner_database_admin.ListDatabaseOperationsResponse( + operations=[ + operations_pb2.Operation(), + operations_pb2.Operation(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + spanner_database_admin.ListDatabaseOperationsResponse.to_json(x) + for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "projects/sample1/instances/sample2"} + + pager = client.list_database_operations(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, operations_pb2.Operation) for i in results) + + pages = list(client.list_database_operations(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + backup.ListBackupOperationsRequest, + dict, + ], +) +def test_list_backup_operations_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": 
"projects/sample1/instances/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = backup.ListBackupOperationsResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = backup.ListBackupOperationsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_backup_operations(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListBackupOperationsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_backup_operations_rest_required_fields( + request_type=backup.ListBackupOperationsRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_backup_operations._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + 
).list_backup_operations._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "filter", + "page_size", + "page_token", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = backup.ListBackupOperationsResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = backup.ListBackupOperationsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.list_backup_operations(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_list_backup_operations_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.list_backup_operations._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "filter", + "pageSize", + "pageToken", + ) + ) + & set(("parent",)) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_backup_operations_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_list_backup_operations" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_list_backup_operations" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = backup.ListBackupOperationsRequest.pb( + 
backup.ListBackupOperationsRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = backup.ListBackupOperationsResponse.to_json( + backup.ListBackupOperationsResponse() + ) + + request = backup.ListBackupOperationsRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = backup.ListBackupOperationsResponse() + + client.list_backup_operations( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_list_backup_operations_rest_bad_request( + transport: str = "rest", request_type=backup.ListBackupOperationsRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_backup_operations(request) + + +def test_list_backup_operations_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = backup.ListBackupOperationsResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "projects/sample1/instances/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = backup.ListBackupOperationsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.list_backup_operations(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/instances/*}/backupOperations" + % client.transport._host, + args[1], + ) + + +def test_list_backup_operations_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_backup_operations( + backup.ListBackupOperationsRequest(), + parent="parent_value", + ) + + +def test_list_backup_operations_rest_pager(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. 
+ # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + backup.ListBackupOperationsResponse( + operations=[ + operations_pb2.Operation(), + operations_pb2.Operation(), + operations_pb2.Operation(), + ], + next_page_token="abc", + ), + backup.ListBackupOperationsResponse( + operations=[], + next_page_token="def", + ), + backup.ListBackupOperationsResponse( + operations=[ + operations_pb2.Operation(), + ], + next_page_token="ghi", + ), + backup.ListBackupOperationsResponse( + operations=[ + operations_pb2.Operation(), + operations_pb2.Operation(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + backup.ListBackupOperationsResponse.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "projects/sample1/instances/sample2"} + + pager = client.list_backup_operations(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, operations_pb2.Operation) for i in results) + + pages = list(client.list_backup_operations(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + spanner_database_admin.ListDatabaseRolesRequest, + dict, + ], +) +def test_list_database_roles_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the 
http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.ListDatabaseRolesResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = spanner_database_admin.ListDatabaseRolesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + response = client.list_database_roles(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDatabaseRolesPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_database_roles_rest_required_fields( + request_type=spanner_database_admin.ListDatabaseRolesRequest, +): + transport_class = transports.DatabaseAdminRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson( + pb_request, + including_default_value_fields=False, + use_integers_for_enums=False, + ) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_database_roles._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_database_roles._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not 
mixing in. + assert not set(unset_fields) - set( + ( + "page_size", + "page_token", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.ListDatabaseRolesResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. 
+ pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = spanner_database_admin.ListDatabaseRolesResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.list_database_roles(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_list_database_roles_rest_unset_required_fields(): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.list_database_roles._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "pageSize", + "pageToken", + ) + ) + & set(("parent",)) + ) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_database_roles_rest_interceptors(null_interceptor): + transport = transports.DatabaseAdminRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DatabaseAdminRestInterceptor(), + ) + client = DatabaseAdminClient(transport=transport) + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "post_list_database_roles" + ) as post, mock.patch.object( + transports.DatabaseAdminRestInterceptor, "pre_list_database_roles" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = spanner_database_admin.ListDatabaseRolesRequest.pb( + 
spanner_database_admin.ListDatabaseRolesRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + req.return_value._content = ( + spanner_database_admin.ListDatabaseRolesResponse.to_json( + spanner_database_admin.ListDatabaseRolesResponse() + ) + ) + + request = spanner_database_admin.ListDatabaseRolesRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = spanner_database_admin.ListDatabaseRolesResponse() + + client.list_database_roles( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_list_database_roles_rest_bad_request( + transport: str = "rest", + request_type=spanner_database_admin.ListDatabaseRolesRequest, +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "projects/sample1/instances/sample2/databases/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_database_roles(request) + + +def test_list_database_roles_rest_flattened(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = spanner_database_admin.ListDatabaseRolesResponse() + + # get arguments that satisfy an http rule for this method + sample_request = { + "parent": "projects/sample1/instances/sample2/databases/sample3" + } + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = spanner_database_admin.ListDatabaseRolesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + client.list_database_roles(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1/{parent=projects/*/instances/*/databases/*}/databaseRoles" + % client.transport._host, + args[1], + ) + + +def test_list_database_roles_rest_flattened_error(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_database_roles( + spanner_database_admin.ListDatabaseRolesRequest(), + parent="parent_value", + ) + + +def test_list_database_roles_rest_pager(transport: str = "rest"): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. 
+ with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + spanner_database_admin.ListDatabaseRolesResponse( + database_roles=[ + spanner_database_admin.DatabaseRole(), + spanner_database_admin.DatabaseRole(), + spanner_database_admin.DatabaseRole(), + ], + next_page_token="abc", + ), + spanner_database_admin.ListDatabaseRolesResponse( + database_roles=[], + next_page_token="def", + ), + spanner_database_admin.ListDatabaseRolesResponse( + database_roles=[ + spanner_database_admin.DatabaseRole(), + ], + next_page_token="ghi", + ), + spanner_database_admin.ListDatabaseRolesResponse( + database_roles=[ + spanner_database_admin.DatabaseRole(), + spanner_database_admin.DatabaseRole(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + spanner_database_admin.ListDatabaseRolesResponse.to_json(x) + for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = { + "parent": "projects/sample1/instances/sample2/databases/sample3" + } + + pager = client.list_database_roles(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, spanner_database_admin.DatabaseRole) for i in results) + + pages = list(client.list_database_roles(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. 
+ transport = transports.DatabaseAdminGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.DatabaseAdminGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DatabaseAdminClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.DatabaseAdminGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = DatabaseAdminClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = mock.Mock() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = DatabaseAdminClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.DatabaseAdminGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DatabaseAdminClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.DatabaseAdminGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = DatabaseAdminClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): # A client may be instantiated with a custom transport instance. 
transport = transports.DatabaseAdminGrpcTransport( credentials=ga_credentials.AnonymousCredentials(), @@ -6406,6 +12880,7 @@ def test_transport_get_channel(): [ transports.DatabaseAdminGrpcTransport, transports.DatabaseAdminGrpcAsyncIOTransport, + transports.DatabaseAdminRestTransport, ], ) def test_transport_adc(transport_class): @@ -6420,6 +12895,7 @@ def test_transport_adc(transport_class): "transport_name", [ "grpc", + "rest", ], ) def test_transport_kind(transport_name): @@ -6465,6 +12941,7 @@ def test_database_admin_base_transport(): "list_databases", "create_database", "get_database", + "update_database", "update_database_ddl", "drop_database", "get_database_ddl", @@ -6585,6 +13062,7 @@ def test_database_admin_transport_auth_adc(transport_class): [ transports.DatabaseAdminGrpcTransport, transports.DatabaseAdminGrpcAsyncIOTransport, + transports.DatabaseAdminRestTransport, ], ) def test_database_admin_transport_auth_gdch_credentials(transport_class): @@ -6685,11 +13163,40 @@ def test_database_admin_grpc_transport_client_cert_source_for_mtls(transport_cla ) +def test_database_admin_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.DatabaseAdminRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +def test_database_admin_rest_lro_client(): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.AbstractOperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. 
+ assert transport.operations_client is transport.operations_client + + @pytest.mark.parametrize( "transport_name", [ "grpc", "grpc_asyncio", + "rest", ], ) def test_database_admin_host_no_port(transport_name): @@ -6700,7 +13207,11 @@ def test_database_admin_host_no_port(transport_name): ), transport=transport_name, ) - assert client.transport._host == ("spanner.googleapis.com:443") + assert client.transport._host == ( + "spanner.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://spanner.googleapis.com" + ) @pytest.mark.parametrize( @@ -6708,6 +13219,7 @@ def test_database_admin_host_no_port(transport_name): [ "grpc", "grpc_asyncio", + "rest", ], ) def test_database_admin_host_with_port(transport_name): @@ -6718,7 +13230,90 @@ def test_database_admin_host_with_port(transport_name): ), transport=transport_name, ) - assert client.transport._host == ("spanner.googleapis.com:8000") + assert client.transport._host == ( + "spanner.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://spanner.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_database_admin_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = DatabaseAdminClient( + credentials=creds1, + transport=transport_name, + ) + client2 = DatabaseAdminClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.list_databases._session + session2 = client2.transport.list_databases._session + assert session1 != session2 + session1 = client1.transport.create_database._session + session2 = client2.transport.create_database._session + assert session1 != session2 + session1 = client1.transport.get_database._session + session2 = client2.transport.get_database._session + assert session1 != session2 + session1 = client1.transport.update_database._session + session2 = 
client2.transport.update_database._session + assert session1 != session2 + session1 = client1.transport.update_database_ddl._session + session2 = client2.transport.update_database_ddl._session + assert session1 != session2 + session1 = client1.transport.drop_database._session + session2 = client2.transport.drop_database._session + assert session1 != session2 + session1 = client1.transport.get_database_ddl._session + session2 = client2.transport.get_database_ddl._session + assert session1 != session2 + session1 = client1.transport.set_iam_policy._session + session2 = client2.transport.set_iam_policy._session + assert session1 != session2 + session1 = client1.transport.get_iam_policy._session + session2 = client2.transport.get_iam_policy._session + assert session1 != session2 + session1 = client1.transport.test_iam_permissions._session + session2 = client2.transport.test_iam_permissions._session + assert session1 != session2 + session1 = client1.transport.create_backup._session + session2 = client2.transport.create_backup._session + assert session1 != session2 + session1 = client1.transport.copy_backup._session + session2 = client2.transport.copy_backup._session + assert session1 != session2 + session1 = client1.transport.get_backup._session + session2 = client2.transport.get_backup._session + assert session1 != session2 + session1 = client1.transport.update_backup._session + session2 = client2.transport.update_backup._session + assert session1 != session2 + session1 = client1.transport.delete_backup._session + session2 = client2.transport.delete_backup._session + assert session1 != session2 + session1 = client1.transport.list_backups._session + session2 = client2.transport.list_backups._session + assert session1 != session2 + session1 = client1.transport.restore_database._session + session2 = client2.transport.restore_database._session + assert session1 != session2 + session1 = client1.transport.list_database_operations._session + session2 = 
client2.transport.list_database_operations._session + assert session1 != session2 + session1 = client1.transport.list_backup_operations._session + session2 = client2.transport.list_backup_operations._session + assert session1 != session2 + session1 = client1.transport.list_database_roles._session + session2 = client2.transport.list_database_roles._session + assert session1 != session2 def test_database_admin_grpc_transport_channel(): @@ -7188,6 +13783,256 @@ async def test_transport_close_async(): close.assert_called_once() +def test_cancel_operation_rest_bad_request( + transport: str = "rest", request_type=operations_pb2.CancelOperationRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + request = request_type() + request = json_format.ParseDict( + { + "name": "projects/sample1/instances/sample2/databases/sample3/operations/sample4" + }, + request, + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.cancel_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.CancelOperationRequest, + dict, + ], +) +def test_cancel_operation_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request_init = { + "name": "projects/sample1/instances/sample2/databases/sample3/operations/sample4" + } + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = None + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "{}" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.cancel_operation(request) + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_operation_rest_bad_request( + transport: str = "rest", request_type=operations_pb2.DeleteOperationRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + request = request_type() + request = json_format.ParseDict( + { + "name": "projects/sample1/instances/sample2/databases/sample3/operations/sample4" + }, + request, + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.delete_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.DeleteOperationRequest, + dict, + ], +) +def test_delete_operation_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request_init = { + "name": "projects/sample1/instances/sample2/databases/sample3/operations/sample4" + } + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = None + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "{}" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.delete_operation(request) + + # Establish that the response is the type that we expect. + assert response is None + + +def test_get_operation_rest_bad_request( + transport: str = "rest", request_type=operations_pb2.GetOperationRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + request = request_type() + request = json_format.ParseDict( + { + "name": "projects/sample1/instances/sample2/databases/sample3/operations/sample4" + }, + request, + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.get_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.GetOperationRequest, + dict, + ], +) +def test_get_operation_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request_init = { + "name": "projects/sample1/instances/sample2/databases/sample3/operations/sample4" + } + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = operations_pb2.Operation() + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.get_operation(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_list_operations_rest_bad_request( + transport: str = "rest", request_type=operations_pb2.ListOperationsRequest +): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + request = request_type() + request = json_format.ParseDict( + {"name": "projects/sample1/instances/sample2/databases/sample3/operations"}, + request, + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + client.list_operations(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.ListOperationsRequest, + dict, + ], +) +def test_list_operations_rest(request_type): + client = DatabaseAdminClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request_init = { + "name": "projects/sample1/instances/sample2/databases/sample3/operations" + } + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. 
+ return_value = operations_pb2.ListOperationsResponse() + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + + response = client.list_operations(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + def test_delete_operation(transport: str = "grpc"): client = DatabaseAdminClient( credentials=ga_credentials.AnonymousCredentials(), @@ -7213,7 +14058,7 @@ def test_delete_operation(transport: str = "grpc"): @pytest.mark.asyncio -async def test_delete_operation_async(transport: str = "grpc"): +async def test_delete_operation_async(transport: str = "grpc_asyncio"): client = DatabaseAdminAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -7260,9 +14105,9 @@ def test_delete_operation_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] @pytest.mark.asyncio @@ -7288,9 +14133,9 @@ async def test_delete_operation_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] def test_delete_operation_from_dict(): @@ -7352,7 +14197,7 @@ def test_cancel_operation(transport: str = "grpc"): @pytest.mark.asyncio -async def test_cancel_operation_async(transport: str = "grpc"): +async def test_cancel_operation_async(transport: str = "grpc_asyncio"): client = DatabaseAdminAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -7399,9 +14244,9 @@ def test_cancel_operation_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] @pytest.mark.asyncio @@ -7427,9 +14272,9 @@ async def test_cancel_operation_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] def test_cancel_operation_from_dict(): @@ -7491,7 +14336,7 @@ def test_get_operation(transport: str = "grpc"): @pytest.mark.asyncio -async def test_get_operation_async(transport: str = "grpc"): +async def test_get_operation_async(transport: str = "grpc_asyncio"): client = DatabaseAdminAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -7540,9 +14385,9 @@ def test_get_operation_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] @pytest.mark.asyncio @@ -7570,9 +14415,9 @@ async def test_get_operation_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] def test_get_operation_from_dict(): @@ -7636,7 +14481,7 @@ def test_list_operations(transport: str = "grpc"): @pytest.mark.asyncio -async def test_list_operations_async(transport: str = "grpc"): +async def test_list_operations_async(transport: str = "grpc_asyncio"): client = DatabaseAdminAsyncClient( credentials=ga_credentials.AnonymousCredentials(), transport=transport, @@ -7685,9 +14530,9 @@ def test_list_operations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] @pytest.mark.asyncio @@ -7715,9 +14560,9 @@ async def test_list_operations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] def test_list_operations_from_dict(): @@ -7758,6 +14603,7 @@ async def test_list_operations_from_dict_async(): def test_transport_close(): transports = { + "rest": "_session", "grpc": "_grpc_channel", } @@ -7775,6 +14621,7 @@ def test_transport_close(): def test_client_ctx(): transports = [ + "rest", "grpc", ] for transport in transports: From f44d51d405b162f8cd62a3331280270a1eedd956 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Mon, 29 Jan 2024 09:04:25 +0000 Subject: [PATCH 04/20] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- .../test_database_admin.py | 406 +++++++++--------- 1 file changed, 203 insertions(+), 
203 deletions(-) diff --git a/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py b/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py index a8292ee195..6f9f99b5d1 100644 --- a/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py +++ b/tests/unit/gapic/spanner_admin_database_v1/test_database_admin.py @@ -843,9 +843,9 @@ def test_list_databases_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -875,9 +875,9 @@ async def test_list_databases_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_list_databases_flattened(): @@ -1260,9 +1260,9 @@ def test_create_database_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1292,9 +1292,9 @@ async def test_create_database_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_create_database_flattened(): @@ -1527,9 +1527,9 @@ def test_get_database_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1559,9 +1559,9 @@ async def test_get_database_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] def test_get_database_flattened(): @@ -1754,9 +1754,9 @@ def test_update_database_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "database.name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "database.name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -1786,9 +1786,9 @@ async def test_update_database_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "database.name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "database.name=name_value", + ) in kw["metadata"] def test_update_database_flattened(): @@ -1999,9 +1999,9 @@ def test_update_database_ddl_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "database=database_value", - ) in kw["metadata"] + "x-goog-request-params", + "database=database_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -2033,9 +2033,9 @@ async def test_update_database_ddl_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "database=database_value", - ) in kw["metadata"] + "x-goog-request-params", + "database=database_value", + ) in kw["metadata"] def test_update_database_ddl_flattened(): @@ -2240,9 +2240,9 @@ def test_drop_database_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "database=database_value", - ) in kw["metadata"] + "x-goog-request-params", + "database=database_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -2270,9 +2270,9 @@ async def test_drop_database_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "database=database_value", - ) in kw["metadata"] + "x-goog-request-params", + "database=database_value", + ) in kw["metadata"] def test_drop_database_flattened(): @@ -2473,9 +2473,9 @@ def test_get_database_ddl_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "database=database_value", - ) in kw["metadata"] + "x-goog-request-params", + "database=database_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -2505,9 +2505,9 @@ async def test_get_database_ddl_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "database=database_value", - ) in kw["metadata"] + "x-goog-request-params", + "database=database_value", + ) in kw["metadata"] def test_get_database_ddl_flattened(): @@ -2709,9 +2709,9 @@ def test_set_iam_policy_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "resource=resource_value", - ) in kw["metadata"] + "x-goog-request-params", + "resource=resource_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -2739,9 +2739,9 @@ async def test_set_iam_policy_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "resource=resource_value", - ) in kw["metadata"] + "x-goog-request-params", + "resource=resource_value", + ) in kw["metadata"] def test_set_iam_policy_from_dict_foreign(): @@ -2959,9 +2959,9 @@ def test_get_iam_policy_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "resource=resource_value", - ) in kw["metadata"] + "x-goog-request-params", + "resource=resource_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -2989,9 +2989,9 @@ async def test_get_iam_policy_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "resource=resource_value", - ) in kw["metadata"] + "x-goog-request-params", + "resource=resource_value", + ) in kw["metadata"] def test_get_iam_policy_from_dict_foreign(): @@ -3213,9 +3213,9 @@ def test_test_iam_permissions_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "resource=resource_value", - ) in kw["metadata"] + "x-goog-request-params", + "resource=resource_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -3247,9 +3247,9 @@ async def test_test_iam_permissions_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "resource=resource_value", - ) in kw["metadata"] + "x-goog-request-params", + "resource=resource_value", + ) in kw["metadata"] def test_test_iam_permissions_from_dict_foreign(): @@ -3474,9 +3474,9 @@ def test_create_backup_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -3506,9 +3506,9 @@ async def test_create_backup_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_create_backup_flattened(): @@ -3720,9 +3720,9 @@ def test_copy_backup_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -3752,9 +3752,9 @@ async def test_copy_backup_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_copy_backup_flattened(): @@ -4006,9 +4006,9 @@ def test_get_backup_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -4036,9 +4036,9 @@ async def test_get_backup_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] def test_get_backup_flattened(): @@ -4258,9 +4258,9 @@ def test_update_backup_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "backup.name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "backup.name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -4288,9 +4288,9 @@ async def test_update_backup_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "backup.name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "backup.name=name_value", + ) in kw["metadata"] def test_update_backup_flattened(): @@ -4488,9 +4488,9 @@ def test_delete_backup_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -4518,9 +4518,9 @@ async def test_delete_backup_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=name_value", - ) in kw["metadata"] + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] def test_delete_backup_flattened(): @@ -4716,9 +4716,9 @@ def test_list_backups_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -4748,9 +4748,9 @@ async def test_list_backups_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_list_backups_flattened(): @@ -5133,9 +5133,9 @@ def test_restore_database_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -5165,9 +5165,9 @@ async def test_restore_database_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_restore_database_flattened(): @@ -5390,9 +5390,9 @@ def test_list_database_operations_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -5424,9 +5424,9 @@ async def test_list_database_operations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_list_database_operations_flattened(): @@ -5834,9 +5834,9 @@ def test_list_backup_operations_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -5868,9 +5868,9 @@ async def test_list_backup_operations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_list_backup_operations_flattened(): @@ -6279,9 +6279,9 @@ def test_list_database_roles_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] @pytest.mark.asyncio @@ -6313,9 +6313,9 @@ async def test_list_database_roles_field_headers_async(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "parent=parent_value", - ) in kw["metadata"] + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] def test_list_database_roles_flattened(): @@ -6867,7 +6867,7 @@ def test_list_databases_rest_flattened(): assert path_template.validate( "%s/v1/{parent=projects/*/instances/*}/databases" % client.transport._host, args[1], - ) + ) def test_list_databases_rest_flattened_error(transport: str = "rest"): @@ -7072,12 +7072,12 @@ def test_create_database_rest_unset_required_fields(): assert set(unset_fields) == ( set(()) & set( - ( - "parent", - "createStatement", + ( + "parent", + "createStatement", + ) ) ) - ) @pytest.mark.parametrize("null_interceptor", [True, False]) @@ -7200,7 +7200,7 @@ def test_create_database_rest_flattened(): assert path_template.validate( "%s/v1/{parent=projects/*/instances/*}/databases" % client.transport._host, args[1], - ) + ) def test_create_database_rest_flattened_error(transport: str = "rest"): @@ -7485,7 +7485,7 @@ def test_get_database_rest_flattened(): assert path_template.validate( "%s/v1/{name=projects/*/instances/*/databases/*}" % client.transport._host, args[1], - ) + ) def test_get_database_rest_flattened_error(transport: str = "rest"): @@ -7732,12 +7732,12 @@ def test_update_database_rest_unset_required_fields(): assert set(unset_fields) == ( set(("updateMask",)) & set( - ( - "database", - "updateMask", + ( + "database", + "updateMask", + ) ) ) - ) @pytest.mark.parametrize("null_interceptor", [True, False]) @@ -7865,7 +7865,7 @@ def test_update_database_rest_flattened(): "%s/v1/{database.name=projects/*/instances/*/databases/*}" % client.transport._host, args[1], - ) + ) def test_update_database_rest_flattened_error(transport: str = "rest"): @@ -8014,12 +8014,12 @@ def test_update_database_ddl_rest_unset_required_fields(): assert set(unset_fields) == ( set(()) & set( - ( - "database", - "statements", + ( + "database", + "statements", + ) ) ) - ) 
@pytest.mark.parametrize("null_interceptor", [True, False]) @@ -8146,7 +8146,7 @@ def test_update_database_ddl_rest_flattened(): "%s/v1/{database=projects/*/instances/*/databases/*}/ddl" % client.transport._host, args[1], - ) + ) def test_update_database_ddl_rest_flattened_error(transport: str = "rest"): @@ -8402,7 +8402,7 @@ def test_drop_database_rest_flattened(): "%s/v1/{database=projects/*/instances/*/databases/*}" % client.transport._host, args[1], - ) + ) def test_drop_database_rest_flattened_error(transport: str = "rest"): @@ -8681,7 +8681,7 @@ def test_get_database_ddl_rest_flattened(): "%s/v1/{database=projects/*/instances/*/databases/*}/ddl" % client.transport._host, args[1], - ) + ) def test_get_database_ddl_rest_flattened_error(transport: str = "rest"): @@ -8831,12 +8831,12 @@ def test_set_iam_policy_rest_unset_required_fields(): assert set(unset_fields) == ( set(()) & set( - ( - "resource", - "policy", + ( + "resource", + "policy", + ) ) ) - ) @pytest.mark.parametrize("null_interceptor", [True, False]) @@ -8955,7 +8955,7 @@ def test_set_iam_policy_rest_flattened(): "%s/v1/{resource=projects/*/instances/*/databases/*}:setIamPolicy" % client.transport._host, args[1], - ) + ) def test_set_iam_policy_rest_flattened_error(transport: str = "rest"): @@ -9221,7 +9221,7 @@ def test_get_iam_policy_rest_flattened(): "%s/v1/{resource=projects/*/instances/*/databases/*}:getIamPolicy" % client.transport._host, args[1], - ) + ) def test_get_iam_policy_rest_flattened_error(transport: str = "rest"): @@ -9373,12 +9373,12 @@ def test_test_iam_permissions_rest_unset_required_fields(): assert set(unset_fields) == ( set(()) & set( - ( - "resource", - "permissions", + ( + "resource", + "permissions", + ) ) ) - ) @pytest.mark.parametrize("null_interceptor", [True, False]) @@ -9500,7 +9500,7 @@ def test_test_iam_permissions_rest_flattened(): "%s/v1/{resource=projects/*/instances/*/databases/*}:testIamPermissions" % client.transport._host, args[1], - ) + ) def 
test_test_iam_permissions_rest_flattened_error(transport: str = "rest"): @@ -9770,13 +9770,13 @@ def test_create_backup_rest_unset_required_fields(): ) ) & set( - ( - "parent", - "backupId", - "backup", + ( + "parent", + "backupId", + "backup", + ) ) ) - ) @pytest.mark.parametrize("null_interceptor", [True, False]) @@ -9900,7 +9900,7 @@ def test_create_backup_rest_flattened(): assert path_template.validate( "%s/v1/{parent=projects/*/instances/*}/backups" % client.transport._host, args[1], - ) + ) def test_create_backup_rest_flattened_error(transport: str = "rest"): @@ -10052,14 +10052,14 @@ def test_copy_backup_rest_unset_required_fields(): assert set(unset_fields) == ( set(()) & set( - ( - "parent", - "backupId", - "sourceBackup", - "expireTime", + ( + "parent", + "backupId", + "sourceBackup", + "expireTime", + ) ) ) - ) @pytest.mark.parametrize("null_interceptor", [True, False]) @@ -10183,7 +10183,7 @@ def test_copy_backup_rest_flattened(): "%s/v1/{parent=projects/*/instances/*}/backups:copy" % client.transport._host, args[1], - ) + ) def test_copy_backup_rest_flattened_error(transport: str = "rest"): @@ -10462,7 +10462,7 @@ def test_get_backup_rest_flattened(): assert path_template.validate( "%s/v1/{name=projects/*/instances/*/backups/*}" % client.transport._host, args[1], - ) + ) def test_get_backup_rest_flattened_error(transport: str = "rest"): @@ -10725,12 +10725,12 @@ def test_update_backup_rest_unset_required_fields(): assert set(unset_fields) == ( set(("updateMask",)) & set( - ( - "backup", - "updateMask", + ( + "backup", + "updateMask", + ) ) ) - ) @pytest.mark.parametrize("null_interceptor", [True, False]) @@ -10856,7 +10856,7 @@ def test_update_backup_rest_flattened(): "%s/v1/{backup.name=projects/*/instances/*/backups/*}" % client.transport._host, args[1], - ) + ) def test_update_backup_rest_flattened_error(transport: str = "rest"): @@ -11105,7 +11105,7 @@ def test_delete_backup_rest_flattened(): assert path_template.validate( 
"%s/v1/{name=projects/*/instances/*/backups/*}" % client.transport._host, args[1], - ) + ) def test_delete_backup_rest_flattened_error(transport: str = "rest"): @@ -11388,7 +11388,7 @@ def test_list_backups_rest_flattened(): assert path_template.validate( "%s/v1/{parent=projects/*/instances/*}/backups" % client.transport._host, args[1], - ) + ) def test_list_backups_rest_flattened_error(transport: str = "rest"): @@ -11591,12 +11591,12 @@ def test_restore_database_rest_unset_required_fields(): assert set(unset_fields) == ( set(()) & set( - ( - "parent", - "databaseId", + ( + "parent", + "databaseId", + ) ) ) - ) @pytest.mark.parametrize("null_interceptor", [True, False]) @@ -11720,7 +11720,7 @@ def test_restore_database_rest_flattened(): "%s/v1/{parent=projects/*/instances/*}/databases:restore" % client.transport._host, args[1], - ) + ) def test_restore_database_rest_flattened_error(transport: str = "rest"): @@ -12019,7 +12019,7 @@ def test_list_database_operations_rest_flattened(): "%s/v1/{parent=projects/*/instances/*}/databaseOperations" % client.transport._host, args[1], - ) + ) def test_list_database_operations_rest_flattened_error(transport: str = "rest"): @@ -12365,7 +12365,7 @@ def test_list_backup_operations_rest_flattened(): "%s/v1/{parent=projects/*/instances/*}/backupOperations" % client.transport._host, args[1], - ) + ) def test_list_backup_operations_rest_flattened_error(transport: str = "rest"): @@ -12715,7 +12715,7 @@ def test_list_database_roles_rest_flattened(): "%s/v1/{parent=projects/*/instances/*/databases/*}/databaseRoles" % client.transport._host, args[1], - ) + ) def test_list_database_roles_rest_flattened_error(transport: str = "rest"): @@ -14105,9 +14105,9 @@ def test_delete_operation_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] @pytest.mark.asyncio @@ -14133,9 +14133,9 @@ async def test_delete_operation_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] def test_delete_operation_from_dict(): @@ -14244,9 +14244,9 @@ def test_cancel_operation_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] @pytest.mark.asyncio @@ -14272,9 +14272,9 @@ async def test_cancel_operation_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] def test_cancel_operation_from_dict(): @@ -14385,9 +14385,9 @@ def test_get_operation_field_headers(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] @pytest.mark.asyncio @@ -14415,9 +14415,9 @@ async def test_get_operation_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] def test_get_operation_from_dict(): @@ -14530,9 +14530,9 @@ def test_list_operations_field_headers(): # Establish that the field header was sent. 
_, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] @pytest.mark.asyncio @@ -14560,9 +14560,9 @@ async def test_list_operations_field_headers_async(): # Establish that the field header was sent. _, _, kw = call.mock_calls[0] assert ( - "x-goog-request-params", - "name=locations", - ) in kw["metadata"] + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] def test_list_operations_from_dict(): From 2f6f9d7800653870b1ca5872a7f24a4678696005 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Mon, 29 Jan 2024 10:31:24 +0000 Subject: [PATCH 05/20] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- google/cloud/spanner_v1/snapshot.py | 8 ++++++-- tests/system/test_database_api.py | 16 +++++++++++++--- tests/system/test_session_api.py | 1 - tests/unit/test_database.py | 1 - 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 8ba9011a82..f9295ce3c2 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -483,9 +483,13 @@ def execute_sql( if self._transaction_id is None: # lock is added to handle the inline begin for first rpc with self._lock: - return self._get_streamed_result_set(restart, request, trace_attributes, column_info) + return self._get_streamed_result_set( + restart, request, trace_attributes, column_info + ) else: - return self._get_streamed_result_set(restart, request, trace_attributes, column_info) + return self._get_streamed_result_set( + restart, request, trace_attributes, column_info + ) def _get_streamed_result_set(self, restart, request, trace_attributes, column_info): iterator = 
_restart_on_unavailable( diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index 8a145ff447..db1d48c8a2 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -444,7 +444,11 @@ def test_update_ddl_w_default_leader_success( def test_create_role_grant_access_success( - not_emulator, shared_instance, databases_to_delete, database_dialect, proto_descriptor_file, + not_emulator, + shared_instance, + databases_to_delete, + database_dialect, + proto_descriptor_file, ): creator_role_parent = _helpers.unique_id("role_parent", separator="_") creator_role_orphan = _helpers.unique_id("role_orphan", separator="_") @@ -469,7 +473,9 @@ def test_create_role_grant_access_success( f"GRANT SELECT ON TABLE contacts TO {creator_role_parent}", ] - operation = temp_db.update_ddl(ddl_statements, proto_descriptors=proto_descriptor_file) + operation = temp_db.update_ddl( + ddl_statements, proto_descriptors=proto_descriptor_file + ) operation.result(DBAPI_OPERATION_TIMEOUT) # raises on failure / timeout. # Perform select with orphan role on table contacts. 
@@ -504,7 +510,11 @@ def test_create_role_grant_access_success( def test_list_database_role_success( - not_emulator, shared_instance, databases_to_delete, database_dialect, proto_descriptor_file + not_emulator, + shared_instance, + databases_to_delete, + database_dialect, + proto_descriptor_file, ): creator_role_parent = _helpers.unique_id("role_parent", separator="_") creator_role_orphan = _helpers.unique_id("role_orphan", separator="_") diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 0e6085bbf8..c8c7234c38 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -1338,7 +1338,6 @@ def _unit_of_work(transaction): def _set_up_proto_table(database): - sd = _sample_data def _unit_of_work(transaction): diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index f8961897c1..11754cbe86 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -390,7 +390,6 @@ def test_default_leader(self): self.assertEqual(database.default_leader, default_leader) def test_proto_descriptors(self): - instance = _Instance(self.INSTANCE_NAME) pool = _Pool() database = self._make_one( From 84d8ef64bc8dc0543ff1c7dc7b2a5af77f43ce2a Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Mon, 29 Jan 2024 10:45:06 +0000 Subject: [PATCH 06/20] fix: fix code --- google/cloud/spanner_v1/instance.py | 1 - samples/samples/snippets.py | 4 ++-- testing/constraints-3.7.txt | 1 - 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/google/cloud/spanner_v1/instance.py b/google/cloud/spanner_v1/instance.py index 7d4c67274f..a67e0e630b 100644 --- a/google/cloud/spanner_v1/instance.py +++ b/google/cloud/spanner_v1/instance.py @@ -511,7 +511,6 @@ def database( database_dialect=database_dialect, database_role=database_role, enable_drop_protection=enable_drop_protection, - proto_descriptors=proto_descriptors, ) def list_databases(self, page_size=None): diff --git a/samples/samples/snippets.py 
b/samples/samples/snippets.py index 209e7eb839..c8c8be45fc 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -40,7 +40,7 @@ from google.type import expr_pb2 from google.iam.v1 import policy_pb2 from google.protobuf import field_mask_pb2 # type: ignore -from samples.samples.testdata import singer_pb2 +from testdata import singer_pb2 OPERATION_TIMEOUT_SECONDS = 240 @@ -2617,7 +2617,7 @@ def insert_proto_columns_data(instance_id, database_id): The database and table must already exist and can be created using `create_database`. """ - spanner_client = spanner.Client(client_options={'api_endpoint':'staging-wrenchworks.sandbox.googleapis.com'}) + spanner_client = spanner.Client() instance = spanner_client.instance(instance_id) database = instance.database(database_id) diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt index 6783c9586c..20170203f5 100644 --- a/testing/constraints-3.7.txt +++ b/testing/constraints-3.7.txt @@ -14,6 +14,5 @@ opentelemetry-api==1.1.0 opentelemetry-sdk==1.1.0 opentelemetry-instrumentation==0.20b0 protobuf==3.20.2 -protobuf==3.20.2 deprecated==1.2.14 grpc-interceptor==0.15.4 From b1bcf343d1672205425829284933ac2c19a620b7 Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Mon, 29 Jan 2024 11:00:11 +0000 Subject: [PATCH 07/20] fix: fix code --- tests/system/_helpers.py | 4 +++- tests/system/test_session_api.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/system/_helpers.py b/tests/system/_helpers.py index b62d453512..3a708e0247 100644 --- a/tests/system/_helpers.py +++ b/tests/system/_helpers.py @@ -65,7 +65,9 @@ ) ) -PROTO_COLUMNS_DDL_STATEMENTS = _fixtures.PROTO_COLUMNS_DDL_STATEMENTS +PROTO_COLUMNS_DDL_STATEMENTS = ( + [] if USE_EMULATOR else _fixtures.PROTO_COLUMNS_DDL_STATEMENTS +) retry_true = retry.RetryResult(operator.truth) retry_false = retry.RetryResult(operator.not_) diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 
c8c7234c38..b881cf081c 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -90,7 +90,7 @@ "proto_enum_array", ) -EMULATOR_ALL_TYPES_COLUMNS = LIVE_ALL_TYPES_COLUMNS[:-4] +EMULATOR_ALL_TYPES_COLUMNS = LIVE_ALL_TYPES_COLUMNS[:-8] # ToDo: Clean up generation of POSTGRES_ALL_TYPES_COLUMNS POSTGRES_ALL_TYPES_COLUMNS = LIVE_ALL_TYPES_COLUMNS[:17] + ( "jsonb_value", From bfcb58daa677fe028c11b9677411bdcdc55ac6c2 Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Fri, 3 May 2024 08:53:04 +0000 Subject: [PATCH 08/20] fix(spanner): fix code --- samples/samples/snippets.py | 2 -- tests/system/_helpers.py | 4 +--- tests/unit/test__helpers.py | 6 ++++-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index d9964f02ab..5c1036b6ea 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -36,8 +36,6 @@ get_proto_message, get_proto_enum, ) -from google.type import expr_pb2 -from google.iam.v1 import policy_pb2 from google.protobuf import field_mask_pb2 # type: ignore from testdata import singer_pb2 diff --git a/tests/system/_helpers.py b/tests/system/_helpers.py index 3a708e0247..b62d453512 100644 --- a/tests/system/_helpers.py +++ b/tests/system/_helpers.py @@ -65,9 +65,7 @@ ) ) -PROTO_COLUMNS_DDL_STATEMENTS = ( - [] if USE_EMULATOR else _fixtures.PROTO_COLUMNS_DDL_STATEMENTS -) +PROTO_COLUMNS_DDL_STATEMENTS = _fixtures.PROTO_COLUMNS_DDL_STATEMENTS retry_true = retry.RetryResult(operator.truth) retry_false = retry.RetryResult(operator.not_) diff --git a/tests/unit/test__helpers.py b/tests/unit/test__helpers.py index 68129c86b2..3bebeb1fae 100644 --- a/tests/unit/test__helpers.py +++ b/tests/unit/test__helpers.py @@ -539,9 +539,10 @@ def test_w_float32(self): VALUE = 3.14159 field_type = Type(code=TypeCode.FLOAT32) + field_name = "float32_column" value_pb = Value(number_value=VALUE) - self.assertEqual(self._callFUT(value_pb, field_type), VALUE) + 
self.assertEqual(self._callFUT(value_pb, field_type, field_name), VALUE) def test_w_float32_str(self): from google.cloud.spanner_v1 import Type, TypeCode @@ -549,10 +550,11 @@ def test_w_float32_str(self): VALUE = "3.14159" field_type = Type(code=TypeCode.FLOAT32) + field_name = "float32_str_column" value_pb = Value(string_value=VALUE) expected_value = 3.14159 - self.assertEqual(self._callFUT(value_pb, field_type), expected_value) + self.assertEqual(self._callFUT(value_pb, field_type, field_name), expected_value) def test_w_date(self): import datetime From 5ff317c3565cc05dbf9cdf4420ebb77386ff47d4 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Fri, 3 May 2024 08:55:37 +0000 Subject: [PATCH 09/20] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- tests/unit/test__helpers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/test__helpers.py b/tests/unit/test__helpers.py index 3bebeb1fae..11adec6ac9 100644 --- a/tests/unit/test__helpers.py +++ b/tests/unit/test__helpers.py @@ -554,7 +554,9 @@ def test_w_float32_str(self): value_pb = Value(string_value=VALUE) expected_value = 3.14159 - self.assertEqual(self._callFUT(value_pb, field_type, field_name), expected_value) + self.assertEqual( + self._callFUT(value_pb, field_type, field_name), expected_value + ) def test_w_date(self): import datetime From 02b3d5bbfd92ec47841ee884880611af76f014a3 Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Fri, 3 May 2024 11:49:42 +0000 Subject: [PATCH 10/20] fix(spanner): skip emulator due to b/338557401 --- tests/_fixtures.py | 2 ++ tests/system/test_session_api.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/tests/_fixtures.py b/tests/_fixtures.py index 7e42a2eaa4..b250306909 100644 --- a/tests/_fixtures.py +++ b/tests/_fixtures.py @@ 
-91,6 +91,8 @@ ) PRIMARY KEY (CartId); """ +# TODO: Add Proto Bundle DDL statement in EMULATOR_DDL once b/338557401 +# is fixed. EMULATOR_DDL = """\ CREATE TABLE contacts ( contact_id INT64, diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index bbe6000aba..2ef4c8eb5b 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -1543,6 +1543,9 @@ def test_multiuse_snapshot_read_isolation_exact_staleness(sessions_database): sd._check_row_data(after, all_data_rows) +@pytest.mark.skipif( + _helpers.USE_EMULATOR, reason="b/338557401" +) def test_read_w_index( shared_instance, database_operation_timeout, From 01c3b2a136e9ec91bcd4f682fe5c02484dcbcb10 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Fri, 3 May 2024 11:52:09 +0000 Subject: [PATCH 11/20] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- tests/system/test_session_api.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 2ef4c8eb5b..39190ed41a 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -1543,9 +1543,7 @@ def test_multiuse_snapshot_read_isolation_exact_staleness(sessions_database): sd._check_row_data(after, all_data_rows) -@pytest.mark.skipif( - _helpers.USE_EMULATOR, reason="b/338557401" -) +@pytest.mark.skipif(_helpers.USE_EMULATOR, reason="b/338557401") def test_read_w_index( shared_instance, database_operation_timeout, From a7d60a106c5166491fbb761e0d4131a357a714b5 Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Fri, 3 May 2024 13:47:28 +0000 Subject: [PATCH 12/20] fix(spanner): remove samples --- samples/samples/snippets.py | 348 +------------------------------ samples/samples/snippets_test.py | 101 
--------- 2 files changed, 2 insertions(+), 447 deletions(-) diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index 5c1036b6ea..a5f8d8653f 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -31,13 +31,8 @@ from google.cloud import spanner from google.cloud.spanner_admin_instance_v1.types import spanner_instance_admin from google.cloud.spanner_v1 import DirectedReadOptions, param_types -from google.cloud.spanner_v1.data_types import ( - JsonObject, - get_proto_message, - get_proto_enum, -) +from google.cloud.spanner_v1.data_types import JsonObject from google.protobuf import field_mask_pb2 # type: ignore -from testdata import singer_pb2 OPERATION_TIMEOUT_SECONDS = 240 @@ -360,57 +355,6 @@ def create_database_with_default_leader(instance_id, database_id, default_leader # [END spanner_create_database_with_default_leader] -# [START spanner_create_database_with_proto_descriptor] -def create_database_with_proto_descriptor(instance_id, database_id): - """Creates a database with proto descriptors and tables with proto columns for sample data.""" - import os - - dirname = os.path.dirname(__file__) - filename = os.path.join(dirname, "testdata/descriptors.pb") - - spanner_client = spanner.Client() - instance = spanner_client.instance(instance_id) - - # reads proto descriptor file as bytes - proto_descriptor_file = open(filename, "rb") - proto_descriptor = proto_descriptor_file.read() - - database = instance.database( - database_id, - ddl_statements=[ - """CREATE PROTO BUNDLE ( - spanner.examples.music.SingerInfo, - spanner.examples.music.Genre, - )""", - """CREATE TABLE Singers ( - SingerId INT64 NOT NULL, - FirstName STRING(1024), - LastName STRING(1024), - SingerInfo spanner.examples.music.SingerInfo, - SingerGenre spanner.examples.music.Genre, - SingerInfoArray ARRAY, - SingerGenreArray ARRAY, - ) PRIMARY KEY (SingerId)""", - ], - proto_descriptors=proto_descriptor, - ) - - operation = database.create() - - print("Waiting 
for operation to complete...") - operation.result(OPERATION_TIMEOUT_SECONDS) - proto_descriptor_file.close() - - print( - "Created database {} with proto descriptors on instance {}".format( - database_id, instance_id - ) - ) - - -# [END spanner_create_database_with_proto_descriptor] - - # [START spanner_update_database_with_default_leader] def update_database_with_default_leader(instance_id, database_id, default_leader): """Updates a database with tables with a default leader.""" @@ -441,51 +385,6 @@ def update_database_with_default_leader(instance_id, database_id, default_leader # [END spanner_update_database_with_default_leader] -# [START spanner_update_database_with_proto_descriptor] -def update_database_with_proto_descriptor(instance_id, database_id): - """Updates a database with tables with a default leader.""" - import os - - dirname = os.path.dirname(__file__) - filename = os.path.join(dirname, "testdata/descriptors.pb") - - spanner_client = spanner.Client() - instance = spanner_client.instance(instance_id) - - database = instance.database(database_id) - proto_descriptor_file = open(filename, "rb") - proto_descriptor = proto_descriptor_file.read() - - operation = database.update_ddl( - [ - """CREATE PROTO BUNDLE ( - spanner.examples.music.SingerInfo, - spanner.examples.music.Genre, - )""", - """CREATE TABLE Singers ( - SingerId INT64 NOT NULL, - FirstName STRING(1024), - LastName STRING(1024), - SingerInfo spanner.examples.music.SingerInfo, - SingerGenre spanner.examples.music.Genre, - SingerInfoArray ARRAY, - SingerGenreArray ARRAY, - ) PRIMARY KEY (SingerId)""", - ], - proto_descriptors=proto_descriptor, - ) - print("Waiting for operation to complete...") - operation.result(OPERATION_TIMEOUT_SECONDS) - proto_descriptor_file.close() - - database.reload() - - print("Database {} updated with proto descriptors".format(database.name)) - - -# [END spanner_update_database_with_proto_descriptor] - - # [START spanner_get_database_ddl] def 
get_database_ddl(instance_id, database_id): """Gets the database DDL statements.""" @@ -2798,213 +2697,6 @@ def enable_fine_grained_access( # [END spanner_enable_fine_grained_access] -# [START spanner_insert_proto_columns_data] -def insert_proto_columns_data(instance_id, database_id): - """Inserts sample proto column data into the given database. - - The database and table must already exist and can be created using - `create_database`. - """ - spanner_client = spanner.Client() - instance = spanner_client.instance(instance_id) - database = instance.database(database_id) - - singer_info = singer_pb2.SingerInfo() - singer_info.singer_id = 2 - singer_info.birth_date = "February" - singer_info.nationality = "Country2" - singer_info.genre = singer_pb2.Genre.FOLK - - singer_info_array = [singer_info] - singer_genre_array = [singer_pb2.Genre.FOLK] - - with database.batch() as batch: - batch.insert( - table="Singers", - columns=( - "SingerId", - "FirstName", - "LastName", - "SingerInfo", - "SingerGenre", - "SingerInfoArray", - "SingerGenreArray", - ), - values=[ - ( - 2, - "Marc", - "Richards", - singer_info, - singer_pb2.Genre.ROCK, - singer_info_array, - singer_genre_array, - ), - (3, "Catalina", "Smith", None, None, None, None), - ], - ) - - print("Inserted data.") - - -# [END spanner_insert_proto_columns_data] - - -# [START spanner_insert_proto_columns_data_using_dml] -def insert_proto_columns_data_using_dml(instance_id, database_id): - """Inserts sample proto column data into the given database using a DML statement.""" - spanner_client = spanner.Client() - instance = spanner_client.instance(instance_id) - database = instance.database(database_id) - - singer_info = singer_pb2.SingerInfo() - singer_info.singer_id = 1 - singer_info.birth_date = "January" - singer_info.nationality = "Country1" - singer_info.genre = singer_pb2.Genre.ROCK - - singer_info_array = [singer_info, None] - singer_genre_array = [singer_pb2.Genre.ROCK, None] - - def 
insert_singers_with_proto_column(transaction): - row_ct = transaction.execute_update( - "INSERT INTO Singers (SingerId, FirstName, LastName, SingerInfo, SingerGenre, SingerInfoArray," - " SingerGenreArray) " - " VALUES (1, 'Virginia', 'Watson', @singerInfo, @singerGenre, @singerInfoArray, @singerGenreArray)", - params={ - "singerInfo": singer_info, - "singerGenre": singer_pb2.Genre.ROCK, - "singerInfoArray": singer_info_array, - "singerGenreArray": singer_genre_array, - }, - param_types={ - "singerInfo": param_types.ProtoMessage(singer_info), - "singerGenre": param_types.ProtoEnum(singer_pb2.Genre), - "singerInfoArray": param_types.Array( - param_types.ProtoMessage(singer_info) - ), - "singerGenreArray": param_types.Array( - param_types.ProtoEnum(singer_pb2.Genre) - ), - }, - ) - - print("{} record(s) inserted.".format(row_ct)) - - database.run_in_transaction(insert_singers_with_proto_column) - - -# [END spanner_insert_proto_columns_data_using_dml] - - -# [START spanner_read_proto_columns_data] -def read_proto_columns_data(instance_id, database_id): - """Reads sample proto column data from the database.""" - spanner_client = spanner.Client() - instance = spanner_client.instance(instance_id) - database = instance.database(database_id) - - with database.snapshot() as snapshot: - keyset = spanner.KeySet(all_=True) - results = snapshot.read( - table="Singers", - columns=( - "SingerId", - "FirstName", - "LastName", - "SingerInfo", - "SingerGenre", - "SingerInfoArray", - "SingerGenreArray", - ), - keyset=keyset, - column_info={ - "SingerInfo": singer_pb2.SingerInfo(), - "SingerGenre": singer_pb2.Genre, - "SingerInfoArray": singer_pb2.SingerInfo(), - "SingerGenreArray": singer_pb2.Genre, - }, - ) - - for row in results: - print( - "SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, , SingerInfoArray: {}, " - "SingerGenreArray: {}".format(*row) - ) - - -# [END spanner_read_proto_columns_data] - - -# [START spanner_read_proto_columns_data_using_dql] 
-def read_proto_columns_data_using_dql(instance_id, database_id): - """Queries sample proto column data from the database using SQL.""" - spanner_client = spanner.Client() - instance = spanner_client.instance(instance_id) - database = instance.database(database_id) - - with database.snapshot() as snapshot: - results = snapshot.execute_sql( - "SELECT SingerId, FirstName, LastName, SingerInfo, SingerGenre, SingerInfoArray, SingerGenreArray FROM Singers", - column_info={ - "SingerInfo": singer_pb2.SingerInfo(), - "SingerGenre": singer_pb2.Genre, - "SingerInfoArray": singer_pb2.SingerInfo(), - "SingerGenreArray": singer_pb2.Genre, - }, - ) - - for row in results: - print( - "SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, , SingerInfoArray: {}, " - "SingerGenreArray: {}".format(*row) - ) - - -# [END spanner_read_proto_columns_data_using_dql] - - -def read_proto_columns_data_using_helper_method(instance_id, database_id): - """Reads sample proto column data from the database.""" - spanner_client = spanner.Client() - instance = spanner_client.instance(instance_id) - database = instance.database(database_id) - - with database.snapshot() as snapshot: - keyset = spanner.KeySet(all_=True) - results = snapshot.read( - table="Singers", - columns=( - "SingerId", - "FirstName", - "LastName", - "SingerInfo", - "SingerGenre", - "SingerInfoArray", - "SingerGenreArray", - ), - keyset=keyset, - ) - - for row in results: - singer_info_proto_msg = get_proto_message(row[3], singer_pb2.SingerInfo()) - singer_genre_proto_enum = get_proto_enum(row[4], singer_pb2.Genre) - singer_info_list = get_proto_message(row[5], singer_pb2.SingerInfo()) - singer_genre_list = get_proto_enum(row[6], singer_pb2.Genre) - print( - "SingerId: {}, FirstName: {}, LastName: {}, SingerInfo: {}, SingerGenre: {}, SingerInfoArray: {}, " - "SingerGenreArray: {}".format( - row[0], - row[1], - row[2], - singer_info_proto_msg, - singer_genre_proto_enum, - singer_info_list, - singer_genre_list, 
- ) - ) - - # [START spanner_create_table_with_foreign_key_delete_cascade] def create_table_with_foreign_key_delete_cascade(instance_id, database_id): """Creates a table with foreign key delete cascade action""" @@ -3464,7 +3156,6 @@ def create_instance_with_autoscaling_config(instance_id): subparsers = parser.add_subparsers(dest="command") subparsers.add_parser("create_instance", help=create_instance.__doc__) subparsers.add_parser("create_database", help=create_database.__doc__) - subparsers.add_parser("get_database_ddl", help=get_database_ddl.__doc__) subparsers.add_parser("insert_data", help=insert_data.__doc__) subparsers.add_parser("batch_write", help=batch_write.__doc__) subparsers.add_parser("delete_data", help=delete_data.__doc__) @@ -3582,28 +3273,7 @@ def create_instance_with_autoscaling_config(instance_id): subparsers.add_parser("create_sequence", help=create_sequence.__doc__) subparsers.add_parser("alter_sequence", help=alter_sequence.__doc__) subparsers.add_parser("drop_sequence", help=drop_sequence.__doc__) - subparsers.add_parser( - "create_database_with_proto_descriptor", - help=create_database_with_proto_descriptor.__doc__, - ) - subparsers.add_parser( - "insert_proto_columns_data_using_dml", - help=insert_proto_columns_data_using_dml.__doc__, - ) - subparsers.add_parser( - "insert_proto_columns_data", help=insert_proto_columns_data.__doc__ - ) - subparsers.add_parser( - "read_proto_columns_data", help=read_proto_columns_data.__doc__ - ) - subparsers.add_parser( - "read_proto_columns_data_using_helper_method", - help=read_proto_columns_data_using_helper_method.__doc__, - ) - subparsers.add_parser( - "read_proto_columns_data_using_dql", - help=read_proto_columns_data_using_dql.__doc__, - ) + enable_fine_grained_access_parser = subparsers.add_parser( "enable_fine_grained_access", help=enable_fine_grained_access.__doc__ ) @@ -3625,8 +3295,6 @@ def create_instance_with_autoscaling_config(instance_id): create_instance(args.instance_id) elif args.command 
== "create_database": create_database(args.instance_id, args.database_id) - elif args.command == "get_database_ddl": - get_database_ddl(args.instance_id, args.database_id) elif args.command == "insert_data": insert_data(args.instance_id, args.database_id) elif args.command == "batch_write": @@ -3759,15 +3427,3 @@ def create_instance_with_autoscaling_config(instance_id): set_custom_timeout_and_retry(args.instance_id, args.database_id) elif args.command == "create_instance_with_autoscaling_config": create_instance_with_autoscaling_config(args.instance_id) - elif args.command == "create_database_with_proto_descriptor": - create_database_with_proto_descriptor(args.instance_id, args.database_id) - elif args.command == "insert_proto_columns_data_using_dml": - insert_proto_columns_data_using_dml(args.instance_id, args.database_id) - elif args.command == "insert_proto_columns_data": - insert_proto_columns_data(args.instance_id, args.database_id) - elif args.command == "read_proto_columns_data": - read_proto_columns_data(args.instance_id, args.database_id) - elif args.command == "read_proto_columns_data_using_helper_method": - read_proto_columns_data_using_helper_method(args.instance_id, args.database_id) - elif args.command == "read_proto_columns_data_using_dql": - read_proto_columns_data_using_dql(args.instance_id, args.database_id) diff --git a/samples/samples/snippets_test.py b/samples/samples/snippets_test.py index 7652956429..b19784d453 100644 --- a/samples/samples/snippets_test.py +++ b/samples/samples/snippets_test.py @@ -44,25 +44,6 @@ INTERLEAVE IN PARENT Singers ON DELETE CASCADE """ -CREATE_TABLE_SINGERS_PROTO = """\ -CREATE TABLE Singers ( -SingerId INT64 NOT NULL, -FirstName STRING(1024), -LastName STRING(1024), -SingerInfo spanner.examples.music.SingerInfo, -SingerGenre spanner.examples.music.Genre, -SingerInfoArray ARRAY, -SingerGenreArray ARRAY, -) PRIMARY KEY (SingerId) -""" - -CREATE_PROTO_BUNDLE = """\ -CREATE PROTO BUNDLE ( - 
spanner.examples.music.SingerInfo, - spanner.examples.music.Genre, - ) -""" - retry_429 = RetryErrors(exceptions.ResourceExhausted, delay=15) @@ -122,15 +103,6 @@ def database_ddl(): return [CREATE_TABLE_SINGERS, CREATE_TABLE_ALBUMS] -@pytest.fixture(scope="module") -def database_ddl_for_proto_columns(): - """Sequence of DDL statements used to set up the database for proto columns. - - Sample testcase modules can override as needed. - """ - return [CREATE_PROTO_BUNDLE, CREATE_TABLE_SINGERS_PROTO] - - @pytest.fixture(scope="module") def default_leader(): """Default leader for multi-region instances.""" @@ -216,13 +188,6 @@ def test_create_database_with_encryption_config( assert kms_key_name in out -def test_create_database_with_proto_descriptor(capsys, instance_id, database_id): - snippets.create_database_with_proto_descriptor(instance_id, database_id) - out, _ = capsys.readouterr() - assert database_id in out - assert instance_id in out - - def test_get_instance_config(capsys): instance_config = "nam6" snippets.get_instance_config(instance_config) @@ -836,72 +801,6 @@ def test_list_database_roles(capsys, instance_id, sample_database): assert "new_parent" in out -def test_update_database_with_proto_descriptor(capsys, sample_instance, create_database_id): - # We have to create a new database here as proto samples also have Singers table and this will clash. 
- sample_instance.database(create_database_id).create().result(240) - snippets.update_database_with_proto_descriptor(sample_instance.instance_id, create_database_id) - out, _ = capsys.readouterr() - assert "updated with proto descriptors" in out - database = sample_instance.database(create_database_id) - database.drop() - - -@pytest.mark.dependency(name="insert_proto_columns_data_dml") -def test_insert_proto_columns_data_using_dml(capsys, instance_id, sample_database_for_proto_columns): - snippets.insert_proto_columns_data_using_dml( - instance_id, sample_database_for_proto_columns.database_id - ) - out, _ = capsys.readouterr() - assert "record(s) inserted" in out - - -@pytest.mark.dependency(name="insert_proto_columns_data") -def test_insert_proto_columns_data(capsys, instance_id, sample_database_for_proto_columns): - snippets.insert_proto_columns_data(instance_id, sample_database_for_proto_columns.database_id) - out, _ = capsys.readouterr() - assert "Inserted data" in out - - -@pytest.mark.dependency( - depends=["insert_proto_columns_data_dml, insert_proto_columns_data"] -) -def test_read_proto_columns_data_using_dql(capsys, instance_id, sample_database_for_proto_columns): - snippets.read_proto_columns_data_using_dql(instance_id, sample_database_for_proto_columns.database_id) - out, _ = capsys.readouterr() - - assert "SingerId: 1, FirstName: Virginia, LastName: Watson" in out - assert "SingerId: 2, FirstName: Marc, LastName: Richards" in out - assert "SingerId: 3, FirstName: Catalina, LastName: Smith" in out - - -@pytest.mark.dependency( - depends=["insert_proto_columns_data_dml, insert_proto_columns_data"] -) -def test_read_proto_columns_data(capsys, instance_id, sample_database_for_proto_columns): - snippets.read_proto_columns_data(instance_id, sample_database_for_proto_columns.database_id) - out, _ = capsys.readouterr() - - assert "SingerId: 1, FirstName: Virginia, LastName: Watson" in out - assert "SingerId: 2, FirstName: Marc, LastName: Richards" in out - 
assert "SingerId: 3, FirstName: Catalina, LastName: Smith" in out - - -@pytest.mark.dependency( - depends=["insert_proto_columns_data_dml, insert_proto_columns_data"] -) -def test_read_proto_columns_data_using_helper_method( - capsys, instance_id, sample_database_for_proto_columns -): - snippets.read_proto_columns_data_using_helper_method( - instance_id, sample_database_for_proto_columns.database_id - ) - out, _ = capsys.readouterr() - - assert "SingerId: 1, FirstName: Virginia, LastName: Watson" in out - assert "SingerId: 2, FirstName: Marc, LastName: Richards" in out - assert "SingerId: 3, FirstName: Catalina, LastName: Smith" in out - - @pytest.mark.dependency(name="create_table_with_foreign_key_delete_cascade") def test_create_table_with_foreign_key_delete_cascade( capsys, instance_id, sample_database From 1fa16054b0a67b6c4f74587fe025b0bcf257b52e Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Fri, 3 May 2024 13:49:15 +0000 Subject: [PATCH 13/20] fix(spanner): update coverage --- noxfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index 9b71c55a7a..ea452e3e93 100644 --- a/noxfile.py +++ b/noxfile.py @@ -313,7 +313,7 @@ def cover(session): test runs (not system test runs), and then erases coverage data. 
""" session.install("coverage", "pytest-cov") - session.run("coverage", "report", "--show-missing", "--fail-under=99") + session.run("coverage", "report", "--show-missing", "--fail-under=98") session.run("coverage", "erase") From f658ba077c339c127b08fb730be65bcc76e7936d Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Fri, 3 May 2024 13:51:43 +0000 Subject: [PATCH 14/20] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- noxfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index ea452e3e93..9b71c55a7a 100644 --- a/noxfile.py +++ b/noxfile.py @@ -313,7 +313,7 @@ def cover(session): test runs (not system test runs), and then erases coverage data. """ session.install("coverage", "pytest-cov") - session.run("coverage", "report", "--show-missing", "--fail-under=98") + session.run("coverage", "report", "--show-missing", "--fail-under=99") session.run("coverage", "erase") From eea661fa95c46c19aace54f2e240c6009cbd3ab3 Mon Sep 17 00:00:00 2001 From: Sri Harsha CH <57220027+harshachinta@users.noreply.github.com> Date: Mon, 6 May 2024 15:19:03 +0530 Subject: [PATCH 15/20] chore(spanner): update coverage --- owlbot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/owlbot.py b/owlbot.py index 2785c226ec..4ef3686ce8 100644 --- a/owlbot.py +++ b/owlbot.py @@ -126,7 +126,7 @@ def get_staging_dirs( templated_files = common.py_library( microgenerator=True, samples=True, - cov_level=99, + cov_level=98, split_system_tests=True, system_test_extras=["tracing"], ) From 9a838e0a4711fb372f8ff239e48571fcf663acd5 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Mon, 6 May 2024 09:51:11 +0000 Subject: [PATCH 16/20] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- noxfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index 9b71c55a7a..ea452e3e93 100644 --- a/noxfile.py +++ b/noxfile.py @@ -313,7 +313,7 @@ def cover(session): test runs (not system test runs), and then erases coverage data. """ session.install("coverage", "pytest-cov") - session.run("coverage", "report", "--show-missing", "--fail-under=99") + session.run("coverage", "report", "--show-missing", "--fail-under=98") session.run("coverage", "erase") From a76980c24090b8116469fdd63524d157931b1d85 Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Sat, 11 May 2024 14:07:12 +0000 Subject: [PATCH 17/20] fix(spanner): add samples and update proto schema --- samples/samples/conftest.py | 48 +++-- samples/samples/snippets.py | 243 ++++++++++++++++++++++++ samples/samples/snippets_test.py | 58 ++++++ samples/samples/testdata/README.md | 5 + samples/samples/testdata/descriptors.pb | Bin 251 -> 251 bytes samples/samples/testdata/singer.proto | 2 +- samples/samples/testdata/singer_pb2.py | 16 +- tests/_fixtures.py | 16 +- tests/system/test_database_api.py | 2 +- tests/system/testdata/descriptors.pb | Bin 251 -> 251 bytes 10 files changed, 340 insertions(+), 50 deletions(-) create mode 100644 samples/samples/testdata/README.md diff --git a/samples/samples/conftest.py b/samples/samples/conftest.py index d2475cad56..b843f98298 100644 --- a/samples/samples/conftest.py +++ b/samples/samples/conftest.py @@ -199,6 +199,29 @@ def database_id(): return "my-database-id" +@pytest.fixture(scope="module") +def proto_columns_database( + spanner_client, + sample_instance, + database_id, + proto_columns_database_ddl, + database_dialect, +): + if database_dialect == DatabaseDialect.GOOGLE_STANDARD_SQL: + sample_database = sample_instance.database( + database_id, + 
ddl_statements=proto_columns_database_ddl, + ) + + if not sample_database.exists(): + operation = sample_database.create() + operation.result(OPERATION_TIMEOUT_SECONDS) + + yield sample_database + + sample_database.drop() + + @pytest.fixture(scope="module") def bit_reverse_sequence_database_id(): """Id for the database used in bit reverse sequence samples. @@ -258,31 +281,6 @@ def sample_database( sample_database.drop() -@pytest.fixture(scope="module") -def sample_database_for_proto_columns( - spanner_client, - sample_instance, - database_id, - database_ddl_for_proto_columns, - database_dialect, - proto_descriptor_file, -): - if database_dialect == DatabaseDialect.GOOGLE_STANDARD_SQL: - sample_database = sample_instance.database( - database_id, - ddl_statements=database_ddl_for_proto_columns, - proto_descriptors=proto_descriptor_file, - ) - - if not sample_database.exists(): - operation = sample_database.create() - operation.result(OPERATION_TIMEOUT_SECONDS) - - yield sample_database - - sample_database.drop() - - @pytest.fixture(scope="module") def bit_reverse_sequence_database( spanner_client, sample_instance, bit_reverse_sequence_database_id, database_dialect diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index a5f8d8653f..058b21ef91 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -33,6 +33,7 @@ from google.cloud.spanner_v1 import DirectedReadOptions, param_types from google.cloud.spanner_v1.data_types import JsonObject from google.protobuf import field_mask_pb2 # type: ignore +from testdata import singer_pb2 OPERATION_TIMEOUT_SECONDS = 240 @@ -3144,6 +3145,228 @@ def create_instance_with_autoscaling_config(instance_id): # [END spanner_create_instance_with_autoscaling_config] +def add_proto_type_columns(instance_id, database_id): + # [START spanner_add_proto_type_columns] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" + + """Adds a new Proto Message column and Proto Enum 
column to the Singers table.""" + import os + + dirname = os.path.dirname(__file__) + filename = os.path.join(dirname, "testdata/descriptors.pb") + + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + + database = instance.database(database_id) + proto_descriptor_file = open(filename, "rb") + proto_descriptor = proto_descriptor_file.read() + + operation = database.update_ddl( + [ + """CREATE PROTO BUNDLE ( + examples.spanner.music.SingerInfo, + examples.spanner.music.Genre, + )""", + "ALTER TABLE Singers ADD COLUMN SingerInfo examples.spanner.music.SingerInfo", + "ALTER TABLE Singers ADD COLUMN SingerInfoArray ARRAY", + "ALTER TABLE Singers ADD COLUMN SingerGenre examples.spanner.music.Genre", + "ALTER TABLE Singers ADD COLUMN SingerGenreArray ARRAY", + ], + proto_descriptors=proto_descriptor, + ) + print("Waiting for operation to complete...") + operation.result(OPERATION_TIMEOUT_SECONDS) + proto_descriptor_file.close() + + database.reload() + + print( + 'Altered table "Singers" on database {} on instance {} with proto descriptors.'.format( + database_id, instance_id + ) + ) + # [END spanner_add_proto_type_columns] + + +def update_data_with_proto_types(instance_id, database_id): + # [START spanner_update_data_with_proto_types] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" + + """Updates Singers tables in the database with the ProtoMessage + and ProtoEnum column. + + This updates the `SingerInfo`, `SingerInfoArray`, `SingerGenre` and + `SingerGenreArray` columns which must be created before + running this sample. 
You can add the column by running the + `add_proto_type_columns` sample or by running this DDL statement + against your database: + + ALTER TABLE Singers ADD COLUMN SingerInfo examples.spanner.music.SingerInfo\n + ALTER TABLE Singers ADD COLUMN SingerInfoArray ARRAY\n + ALTER TABLE Singers ADD COLUMN SingerGenre examples.spanner.music.Genre\n + ALTER TABLE Singers ADD COLUMN SingerGenreArray ARRAY\n + """ + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + singer_info = singer_pb2.SingerInfo() + singer_info.singer_id = 2 + singer_info.birth_date = "February" + singer_info.nationality = "Country2" + singer_info.genre = singer_pb2.Genre.FOLK + + singer_info_array = [singer_info] + + singer_genre_array = [singer_pb2.Genre.FOLK] + + with database.batch() as batch: + batch.update( + table="Singers", + columns=( + "SingerId", + "SingerInfo", + "SingerInfoArray", + "SingerGenre", + "SingerGenreArray", + ), + values=[ + ( + 2, + singer_info, + singer_info_array, + singer_pb2.Genre.FOLK, + singer_genre_array, + ), + (3, None, None, None, None), + ], + ) + + print("Data updated.") + # [END spanner_update_data_with_proto_types] + + +def update_data_with_proto_types_with_dml(instance_id, database_id): + # [START spanner_update_data_with_proto_types_with_dml] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" + + """Updates Singers tables in the database with the ProtoMessage + and ProtoEnum column. + + This updates the `SingerInfo`, `SingerInfoArray`, `SingerGenre` and `SingerGenreArray` columns which must be created before + running this sample. 
You can add the column by running the + `add_proto_type_columns` sample or by running this DDL statement + against your database: + + ALTER TABLE Singers ADD COLUMN SingerInfo examples.spanner.music.SingerInfo\n + ALTER TABLE Singers ADD COLUMN SingerInfoArray ARRAY\n + ALTER TABLE Singers ADD COLUMN SingerGenre examples.spanner.music.Genre\n + ALTER TABLE Singers ADD COLUMN SingerGenreArray ARRAY\n + """ + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + singer_info = singer_pb2.SingerInfo() + singer_info.singer_id = 1 + singer_info.birth_date = "January" + singer_info.nationality = "Country1" + singer_info.genre = singer_pb2.Genre.ROCK + + singer_info_array = [singer_info, None] + + singer_genre_array = [singer_pb2.Genre.ROCK, None] + + def update_singers_with_proto_types(transaction): + row_ct = transaction.execute_update( + "UPDATE Singers " + "SET SingerInfo = @singerInfo, SingerInfoArray=@singerInfoArray, " + "SingerGenre=@singerGenre, SingerGenreArray=@singerGenreArray " + "WHERE SingerId = 1", + params={ + "singerInfo": singer_info, + "singerInfoArray": singer_info_array, + "singerGenre": singer_pb2.Genre.ROCK, + "singerGenreArray": singer_genre_array, + }, + param_types={ + "singerInfo": param_types.ProtoMessage(singer_info), + "singerInfoArray": param_types.Array( + param_types.ProtoMessage(singer_info) + ), + "singerGenre": param_types.ProtoEnum(singer_pb2.Genre), + "singerGenreArray": param_types.Array( + param_types.ProtoEnum(singer_pb2.Genre) + ), + }, + ) + + print("{} record(s) updated.".format(row_ct)) + + database.run_in_transaction(update_singers_with_proto_types) + + def update_singers_with_proto_field(transaction): + row_ct = transaction.execute_update( + "UPDATE Singers " + "SET SingerInfo.nationality = @singerNationality " + "WHERE SingerId = 1", + params={ + "singerNationality": "Country2", + }, + param_types={ + "singerNationality": param_types.STRING, + }, + 
) + + print("{} record(s) updated.".format(row_ct)) + + database.run_in_transaction(update_singers_with_proto_field) + # [END spanner_update_data_with_proto_types_with_dml] + + +def query_data_with_proto_types_parameter(instance_id, database_id): + # [START spanner_query_with_proto_types_parameter] + # instance_id = "your-spanner-instance" + # database_id = "your-spanner-db-id" + + spanner_client = spanner.Client() + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + with database.snapshot() as snapshot: + results = snapshot.execute_sql( + "SELECT SingerId, SingerInfo, SingerInfo.nationality, SingerInfoArray, " + "SingerGenre, SingerGenreArray FROM Singers " + "WHERE SingerInfo.Nationality=@country " + "and SingerGenre=@singerGenre", + params={ + "country": "Country2", + "singerGenre": singer_pb2.Genre.FOLK, + }, + param_types={ + "country": param_types.STRING, + "singerGenre": param_types.ProtoEnum(singer_pb2.Genre), + }, + column_info={ + "SingerInfo": singer_pb2.SingerInfo(), + "SingerInfoArray": singer_pb2.SingerInfo(), + "SingerGenre": singer_pb2.Genre, + "SingerGenreArray": singer_pb2.Genre, + }, + ) + + for row in results: + print( + "SingerId: {}, SingerInfo: {}, SingerGenre: {}, " + "SingerInfoArray: {}, SingerGenreArray: {}".format(*row) + ) + # [END spanner_query_with_proto_types_parameter] + + if __name__ == "__main__": # noqa: C901 parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter @@ -3288,6 +3511,18 @@ def create_instance_with_autoscaling_config(instance_id): subparsers.add_parser( "set_custom_timeout_and_retry", help=set_custom_timeout_and_retry.__doc__ ) + subparsers.add_parser("add_proto_type_columns", help=add_proto_type_columns.__doc__) + subparsers.add_parser( + "update_data_with_proto_types", help=update_data_with_proto_types.__doc__ + ) + subparsers.add_parser( + "update_data_with_proto_types_with_dml", + 
help=update_data_with_proto_types_with_dml.__doc__, + ) + subparsers.add_parser( + "query_data_with_proto_types_parameter", + help=query_data_with_proto_types_parameter.__doc__, + ) args = parser.parse_args() @@ -3427,3 +3662,11 @@ def create_instance_with_autoscaling_config(instance_id): set_custom_timeout_and_retry(args.instance_id, args.database_id) elif args.command == "create_instance_with_autoscaling_config": create_instance_with_autoscaling_config(args.instance_id) + elif args.command == "add_proto_type_columns": + add_proto_type_columns(args.instance_id, args.database_id) + elif args.command == "update_data_with_proto_types": + update_data_with_proto_types(args.instance_id, args.database_id) + elif args.command == "update_data_with_proto_types_with_dml": + update_data_with_proto_types_with_dml(args.instance_id, args.database_id) + elif args.command == "query_data_with_proto_types_parameter": + query_data_with_proto_types_parameter(args.instance_id, args.database_id) diff --git a/samples/samples/snippets_test.py b/samples/samples/snippets_test.py index b19784d453..865010c8bb 100644 --- a/samples/samples/snippets_test.py +++ b/samples/samples/snippets_test.py @@ -44,6 +44,14 @@ INTERLEAVE IN PARENT Singers ON DELETE CASCADE """ +CREATE_TABLE_SINGERS_ = """\ +CREATE TABLE Singers ( + SingerId INT64 NOT NULL, + FirstName STRING(1024), + LastName STRING(1024), + ) PRIMARY KEY (SingerId) +""" + retry_429 = RetryErrors(exceptions.ResourceExhausted, delay=15) @@ -103,6 +111,15 @@ def database_ddl(): return [CREATE_TABLE_SINGERS, CREATE_TABLE_ALBUMS] +@pytest.fixture(scope="module") +def proto_columns_database_ddl(): + """Sequence of DDL statements used to set up the database for proto columns. + + Sample testcase modules can override as needed. 
+ """ + return [CREATE_TABLE_SINGERS_, CREATE_TABLE_ALBUMS] + + @pytest.fixture(scope="module") def default_leader(): """Default leader for multi-region instances.""" @@ -885,3 +902,44 @@ def test_set_custom_timeout_and_retry(capsys, instance_id, sample_database): snippets.set_custom_timeout_and_retry(instance_id, sample_database.database_id) out, _ = capsys.readouterr() assert "SingerId: 1, AlbumId: 1, AlbumTitle: Total Junk" in out + + +@pytest.mark.dependency( + name="add_proto_types_column", +) +def test_add_proto_types_column(capsys, instance_id, proto_columns_database): + snippets.add_proto_type_columns(instance_id, proto_columns_database.database_id) + out, _ = capsys.readouterr() + assert 'Altered table "Singers" on database ' in out + + snippets.insert_data(instance_id, proto_columns_database.database_id) + + +@pytest.mark.dependency( + name="update_data_with_proto_message", depends=["add_proto_types_column"] +) +def test_update_data_with_proto_types(capsys, instance_id, proto_columns_database): + snippets.update_data_with_proto_types( + instance_id, proto_columns_database.database_id + ) + out, _ = capsys.readouterr() + assert "Data updated" in out + + snippets.update_data_with_proto_types_with_dml( + instance_id, proto_columns_database.database_id + ) + out, _ = capsys.readouterr() + assert "1 record(s) updated." 
in out + + +@pytest.mark.dependency( + depends=["add_proto_types_column", "update_data_with_proto_message"] +) +def test_query_data_with_proto_types_parameter( + capsys, instance_id, proto_columns_database +): + snippets.query_data_with_proto_types_parameter( + instance_id, proto_columns_database.database_id + ) + out, _ = capsys.readouterr() + assert "SingerId: 2, SingerInfo: singer_id: 2" in out diff --git a/samples/samples/testdata/README.md b/samples/samples/testdata/README.md new file mode 100644 index 0000000000..b4ff1b649b --- /dev/null +++ b/samples/samples/testdata/README.md @@ -0,0 +1,5 @@ +#### To generate singer_pb2.py and descriptors.pb file from singer.proto using `protoc` +```shell +cd samples/samples +protoc --proto_path=testdata/ --include_imports --descriptor_set_out=testdata/descriptors.pb --python_out=testdata/ testdata/singer.proto +``` diff --git a/samples/samples/testdata/descriptors.pb b/samples/samples/testdata/descriptors.pb index 3ebb79420b3ffd2ca3b3b57433a4a10bfa22b675..d4c018f3a3c21b18f68820eeab130d8195064e81 100644 GIT binary patch delta 63 zcmey(_?uCg>jxtjPjO~mdTNngK~a85zK~dIMPhD2PHM4UaY15UUTV=qjxtjPjO~mdTNngK~a85zK~dPL1JDWkegbOm|KvOT0Bv?RRBY5hcf`^ C3>YK; diff --git a/samples/samples/testdata/singer.proto b/samples/samples/testdata/singer.proto index 8dde1bccae..60276440d7 100644 --- a/samples/samples/testdata/singer.proto +++ b/samples/samples/testdata/singer.proto @@ -1,6 +1,6 @@ syntax = "proto2"; -package spanner.examples.music; +package examples.spanner.music; message SingerInfo { optional int64 singer_id = 1; diff --git a/samples/samples/testdata/singer_pb2.py b/samples/samples/testdata/singer_pb2.py index cdb44c74af..b29049c79a 100644 --- a/samples/samples/testdata/singer_pb2.py +++ b/samples/samples/testdata/singer_pb2.py @@ -1,18 +1,4 @@ # -*- coding: utf-8 -*- -# Copyright 2022 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# # Generated by the protocol buffer compiler. DO NOT EDIT! # source: singer.proto """Generated protocol buffer code.""" @@ -27,7 +13,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0csinger.proto\x12\x16spanner.examples.music\"v\n\nSingerInfo\x12\x11\n\tsinger_id\x18\x01 \x01(\x03\x12\x12\n\nbirth_date\x18\x02 \x01(\t\x12\x13\n\x0bnationality\x18\x03 \x01(\t\x12,\n\x05genre\x18\x04 \x01(\x0e\x32\x1d.spanner.examples.music.Genre*.\n\x05Genre\x12\x07\n\x03POP\x10\x00\x12\x08\n\x04JAZZ\x10\x01\x12\x08\n\x04\x46OLK\x10\x02\x12\x08\n\x04ROCK\x10\x03') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0csinger.proto\x12\x16\x65xamples.spanner.music\"v\n\nSingerInfo\x12\x11\n\tsinger_id\x18\x01 \x01(\x03\x12\x12\n\nbirth_date\x18\x02 \x01(\t\x12\x13\n\x0bnationality\x18\x03 \x01(\t\x12,\n\x05genre\x18\x04 \x01(\x0e\x32\x1d.examples.spanner.music.Genre*.\n\x05Genre\x12\x07\n\x03POP\x10\x00\x12\x08\n\x04JAZZ\x10\x01\x12\x08\n\x04\x46OLK\x10\x02\x12\x08\n\x04ROCK\x10\x03') _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'singer_pb2', globals()) diff --git a/tests/_fixtures.py b/tests/_fixtures.py index b250306909..a2941d19da 100644 --- a/tests/_fixtures.py +++ b/tests/_fixtures.py @@ -29,8 +29,8 @@ PRIMARY KEY (contact_id, phone_type), INTERLEAVE IN PARENT contacts ON DELETE CASCADE; CREATE PROTO BUNDLE ( - spanner.examples.music.SingerInfo, - spanner.examples.music.Genre, + examples.spanner.music.SingerInfo, + examples.spanner.music.Genre, ); CREATE TABLE 
all_types ( pkey INT64 NOT NULL, @@ -52,10 +52,10 @@ numeric_array ARRAY, json_value JSON, json_array ARRAY, - proto_message_value spanner.examples.music.SingerInfo, - proto_message_array ARRAY, - proto_enum_value spanner.examples.music.Genre, - proto_enum_array ARRAY, + proto_message_value examples.spanner.music.SingerInfo, + proto_message_array ARRAY, + proto_enum_value examples.spanner.music.Genre, + proto_enum_array ARRAY, ) PRIMARY KEY (pkey); CREATE TABLE counters ( @@ -200,8 +200,8 @@ singer_id INT64 NOT NULL, first_name STRING(1024), last_name STRING(1024), - singer_info spanner.examples.music.SingerInfo, - singer_genre spanner.examples.music.Genre, ) + singer_info examples.spanner.music.SingerInfo, + singer_genre examples.spanner.music.Genre, ) PRIMARY KEY (singer_id); CREATE INDEX SingerByGenre ON singers(singer_genre) STORING (first_name, last_name); """ diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index 8ddb097366..244fccd069 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -905,7 +905,7 @@ def test_create_table_with_proto_columns( ): proto_cols_db_id = _helpers.unique_id("proto-columns") extra_ddl = [ - "CREATE PROTO BUNDLE (spanner.examples.music.SingerInfo, spanner.examples.music.Genre,)" + "CREATE PROTO BUNDLE (examples.spanner.music.SingerInfo, examples.spanner.music.Genre,)" ] proto_cols_database = shared_instance.database( diff --git a/tests/system/testdata/descriptors.pb b/tests/system/testdata/descriptors.pb index 3ebb79420b3ffd2ca3b3b57433a4a10bfa22b675..d4c018f3a3c21b18f68820eeab130d8195064e81 100644 GIT binary patch delta 63 zcmey(_?uCg>jxtjPjO~mdTNngK~a85zK~dIMPhD2PHM4UaY15UUTV=qjxtjPjO~mdTNngK~a85zK~dPL1JDWkegbOm|KvOT0Bv?RRBY5hcf`^ C3>YK; From 19e97d113a4ab36a31a0ec1acbf5ba066a647bc0 Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Mon, 13 May 2024 05:35:22 +0000 Subject: [PATCH 18/20] fix(spanner): update samples database and emulator DDL --- 
samples/samples/conftest.py | 4 ++-- samples/samples/snippets_test.py | 5 +++++ tests/_fixtures.py | 4 ++++ tests/system/test_session_api.py | 1 - 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/samples/samples/conftest.py b/samples/samples/conftest.py index b843f98298..9810a41d45 100644 --- a/samples/samples/conftest.py +++ b/samples/samples/conftest.py @@ -203,13 +203,13 @@ def database_id(): def proto_columns_database( spanner_client, sample_instance, - database_id, + proto_columns_database_id, proto_columns_database_ddl, database_dialect, ): if database_dialect == DatabaseDialect.GOOGLE_STANDARD_SQL: sample_database = sample_instance.database( - database_id, + proto_columns_database_id, ddl_statements=proto_columns_database_ddl, ) diff --git a/samples/samples/snippets_test.py b/samples/samples/snippets_test.py index 865010c8bb..909305a65a 100644 --- a/samples/samples/snippets_test.py +++ b/samples/samples/snippets_test.py @@ -102,6 +102,11 @@ def default_leader_database_id(): return f"leader_db_{uuid.uuid4().hex[:10]}" +@pytest.fixture(scope="module") +def proto_columns_database_id(): + return f"test-db-proto-{uuid.uuid4().hex[:10]}" + + @pytest.fixture(scope="module") def database_ddl(): """Sequence of DDL statements used to set up the database. 
diff --git a/tests/_fixtures.py b/tests/_fixtures.py index a2941d19da..cbdc110ab4 100644 --- a/tests/_fixtures.py +++ b/tests/_fixtures.py @@ -106,6 +106,10 @@ phone_number STRING(1024) ) PRIMARY KEY (contact_id, phone_type), INTERLEAVE IN PARENT contacts ON DELETE CASCADE; +CREATE PROTO BUNDLE ( + examples.spanner.music.SingerInfo, + examples.spanner.music.Genre, + ); CREATE TABLE all_types ( pkey INT64 NOT NULL, int_value INT64, diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py index 39190ed41a..bbe6000aba 100644 --- a/tests/system/test_session_api.py +++ b/tests/system/test_session_api.py @@ -1543,7 +1543,6 @@ def test_multiuse_snapshot_read_isolation_exact_staleness(sessions_database): sd._check_row_data(after, all_data_rows) -@pytest.mark.skipif(_helpers.USE_EMULATOR, reason="b/338557401") def test_read_w_index( shared_instance, database_operation_timeout, From 60afda1b77b7b89003f49d7bb0f5195faf2208ab Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Tue, 14 May 2024 07:33:06 +0000 Subject: [PATCH 19/20] fix(spanner): update admin test to use autogenerated interfaces --- samples/samples/snippets.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index 058b21ef91..c9e2178894 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -3151,20 +3151,24 @@ def add_proto_type_columns(instance_id, database_id): # database_id = "your-spanner-db-id" """Adds a new Proto Message column and Proto Enum column to the Singers table.""" + import os + from google.cloud.spanner_admin_database_v1.types import spanner_database_admin dirname = os.path.dirname(__file__) filename = os.path.join(dirname, "testdata/descriptors.pb") spanner_client = spanner.Client() - instance = spanner_client.instance(instance_id) + database_admin_api = spanner_client.database_admin_api - database = instance.database(database_id) proto_descriptor_file = 
open(filename, "rb") proto_descriptor = proto_descriptor_file.read() - operation = database.update_ddl( - [ + request = spanner_database_admin.UpdateDatabaseDdlRequest( + database=database_admin_api.database_path( + spanner_client.project, instance_id, database_id + ), + statements=[ """CREATE PROTO BUNDLE ( examples.spanner.music.SingerInfo, examples.spanner.music.Genre, @@ -3176,12 +3180,13 @@ def add_proto_type_columns(instance_id, database_id): ], proto_descriptors=proto_descriptor, ) + + operation = database_admin_api.update_database_ddl(request) + print("Waiting for operation to complete...") operation.result(OPERATION_TIMEOUT_SECONDS) proto_descriptor_file.close() - database.reload() - print( 'Altered table "Singers" on database {} on instance {} with proto descriptors.'.format( database_id, instance_id From d0504793e032fd72b4ff081d8cd066266e33de21 Mon Sep 17 00:00:00 2001 From: Sri Harsha CH Date: Wed, 15 May 2024 07:05:27 +0000 Subject: [PATCH 20/20] fix(spanner): comment refactoring --- google/cloud/spanner_v1/_helpers.py | 8 +++++++- google/cloud/spanner_v1/session.py | 16 ++++++++++++++-- google/cloud/spanner_v1/snapshot.py | 16 ++++++++++++++-- samples/samples/snippets.py | 12 ++++++++++-- tests/_fixtures.py | 2 -- 5 files changed, 45 insertions(+), 9 deletions(-) diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index f77faf1ed7..a1d6a60cb0 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -254,7 +254,13 @@ def _parse_value_pb(value_pb, field_type, field_name, column_info=None): :param field_name: column name :type column_info: dict - :param column_info: (Optional) dict of column name and column information + :param column_info: (Optional) dict of column name and column information. + An object where column names as keys and custom objects as corresponding + values for deserialization. 
It's specifically useful for data types like + protobuf where deserialization logic is on user-specific code. When provided, + the custom object enables deserialization of backend-received column data. + If not provided, data remains serialized as bytes for Proto Messages and + integer for Proto Enums. :rtype: varies on field_type :returns: value extracted from value_pb diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index 3552dee8b3..52994e58e2 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -248,7 +248,13 @@ def read(self, table, columns, keyset, index="", limit=0, column_info=None): :param limit: (Optional) maximum number of rows to return :type column_info: dict - :param column_info: (Optional) dict of mapping between column names and additional column information + :param column_info: (Optional) dict of mapping between column names and additional column information. + An object where column names as keys and custom objects as corresponding + values for deserialization. It's specifically useful for data types like + protobuf where deserialization logic is on user-specific code. When provided, + the custom object enables deserialization of backend-received column data. + If not provided, data remains serialized as bytes for Proto Messages and + integer for Proto Enums. :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. @@ -308,7 +314,13 @@ def execute_sql( :param timeout: (Optional) The timeout for this request. :type column_info: dict - :param column_info: (Optional) dict of mapping between column names and additional column information + :param column_info: (Optional) dict of mapping between column names and additional column information. + An object where column names as keys and custom objects as corresponding + values for deserialization. 
It's specifically useful for data types like + protobuf where deserialization logic is on user-specific code. When provided, + the custom object enables deserialization of backend-received column data. + If not provided, data remains serialized as bytes for Proto Messages and + integer for Proto Enums. :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. diff --git a/google/cloud/spanner_v1/snapshot.py b/google/cloud/spanner_v1/snapshot.py index 1257cb2782..3bc1a746bd 100644 --- a/google/cloud/spanner_v1/snapshot.py +++ b/google/cloud/spanner_v1/snapshot.py @@ -233,7 +233,13 @@ def read( or regions should be used for non-transactional reads or queries. :type column_info: dict - :param column_info: (Optional) dict of mapping between column names and additional column information + :param column_info: (Optional) dict of mapping between column names and additional column information. + An object where column names as keys and custom objects as corresponding + values for deserialization. It's specifically useful for data types like + protobuf where deserialization logic is on user-specific code. When provided, + the custom object enables deserialization of backend-received column data. + If not provided, data remains serialized as bytes for Proto Messages and + integer for Proto Enums. :rtype: :class:`~google.cloud.spanner_v1.streamed.StreamedResultSet` :returns: a result set instance which can be used to consume rows. @@ -407,7 +413,13 @@ def execute_sql( or regions should be used for non-transactional reads or queries. :type column_info: dict - :param column_info: (Optional) dict of mapping between column names and additional column information + :param column_info: (Optional) dict of mapping between column names and additional column information. + An object where column names as keys and custom objects as corresponding + values for deserialization. 
It's specifically useful for data types like + protobuf where deserialization logic is on user-specific code. When provided, + the custom object enables deserialization of backend-received column data. + If not provided, data remains serialized as bytes for Proto Messages and + integer for Proto Enums. :raises ValueError: for reuse of single-use snapshots, or if a transaction ID is diff --git a/samples/samples/snippets.py b/samples/samples/snippets.py index c9e2178894..e7c76685d3 100644 --- a/samples/samples/snippets.py +++ b/samples/samples/snippets.py @@ -3356,6 +3356,12 @@ def query_data_with_proto_types_parameter(instance_id, database_id): "country": param_types.STRING, "singerGenre": param_types.ProtoEnum(singer_pb2.Genre), }, + # column_info is an optional parameter and is used to deserialize + # the proto message and enum object back from bytearray and + # int respectively. + # If column_info is not passed for proto messages and enums, then + # the data types for these columns will be bytes and int + # respectively. column_info={ "SingerInfo": singer_pb2.SingerInfo(), "SingerInfoArray": singer_pb2.SingerInfo(), @@ -3366,8 +3372,10 @@ def query_data_with_proto_types_parameter(instance_id, database_id): for row in results: print( - "SingerId: {}, SingerInfo: {}, SingerGenre: {}, " - "SingerInfoArray: {}, SingerGenreArray: {}".format(*row) + "SingerId: {}, SingerInfo: {}, SingerInfoNationality: {}, " + "SingerInfoArray: {}, SingerGenre: {}, SingerGenreArray: {}".format( + *row + ) ) # [END spanner_query_with_proto_types_parameter] diff --git a/tests/_fixtures.py b/tests/_fixtures.py index cbdc110ab4..7a80adc00a 100644 --- a/tests/_fixtures.py +++ b/tests/_fixtures.py @@ -91,8 +91,6 @@ ) PRIMARY KEY (CartId); """ -# TODO: Add Proto Bundle DDL statement in EMULATOR_DDL once b/338557401 -# is fixed. EMULATOR_DDL = """\ CREATE TABLE contacts ( contact_id INT64,