Skip to content

Commit eb9c905

Browse files
authored
Unimported LRO type annotation fix (#330)
Create wrapped Protos in two stages: 1) load messages and enums; 2) load services and methods. This allows methods to reference _all_ messages in the entire API surface without having to rely on explicit imports. This is a workaround for a common case of #318, where an LRO response or metadata type is referenced as a string in the method annotation but is not a visible, imported type.
1 parent 5d8e90b commit eb9c905

File tree

2 files changed

+263
-99
lines changed

2 files changed

+263
-99
lines changed

packages/gapic-generator/gapic/schema/api.py

Lines changed: 187 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import os
2424
import sys
2525
from itertools import chain
26-
from typing import Callable, Container, Dict, FrozenSet, Mapping, Sequence, Set, Tuple
26+
from typing import Callable, Container, Dict, FrozenSet, Mapping, Optional, Sequence, Set, Tuple
2727

2828
from google.api_core import exceptions # type: ignore
2929
from google.longrunning import operations_pb2 # type: ignore
@@ -56,10 +56,13 @@ def __getattr__(self, name: str):
5656
return getattr(self.file_pb2, name)
5757

5858
@classmethod
59-
def build(cls, file_descriptor: descriptor_pb2.FileDescriptorProto,
60-
file_to_generate: bool, naming: api_naming.Naming,
61-
prior_protos: Mapping[str, 'Proto'] = None,
62-
) -> 'Proto':
59+
def build(
60+
cls, file_descriptor: descriptor_pb2.FileDescriptorProto,
61+
file_to_generate: bool, naming: api_naming.Naming,
62+
opts: options.Options = options.Options(),
63+
prior_protos: Mapping[str, 'Proto'] = None,
64+
load_services: bool = True
65+
) -> 'Proto':
6366
"""Build and return a Proto instance.
6467
6568
Args:
@@ -71,12 +74,18 @@ def build(cls, file_descriptor: descriptor_pb2.FileDescriptorProto,
7174
with the API.
7275
prior_protos (~.Proto): Previous, already processed protos.
7376
These are needed to look up messages in imported protos.
77+
load_services (bool): Toggle whether the proto file should
78+
load its services. Not doing so enables a two-pass fix for
79+
LRO response and metadata types in certain situations.
7480
"""
75-
return _ProtoBuilder(file_descriptor,
76-
file_to_generate=file_to_generate,
77-
naming=naming,
78-
prior_protos=prior_protos or {},
79-
).proto
81+
return _ProtoBuilder(
82+
file_descriptor,
83+
file_to_generate=file_to_generate,
84+
naming=naming,
85+
opts=opts,
86+
prior_protos=prior_protos or {},
87+
load_services=load_services
88+
).proto
8089

8190
@cached_property
8291
def enums(self) -> Mapping[str, wrappers.EnumType]:
@@ -184,10 +193,13 @@ class API:
184193
subpackage_view: Tuple[str, ...] = dataclasses.field(default_factory=tuple)
185194

186195
@classmethod
187-
def build(cls,
188-
file_descriptors: Sequence[descriptor_pb2.FileDescriptorProto],
189-
package: str = '',
190-
opts: options.Options = options.Options()) -> 'API':
196+
def build(
197+
cls,
198+
file_descriptors: Sequence[descriptor_pb2.FileDescriptorProto],
199+
package: str = '',
200+
opts: options.Options = options.Options(),
201+
prior_protos: Mapping[str, 'Proto'] = None,
202+
) -> 'API':
191203
"""Build the internal API schema based on the request.
192204
193205
Args:
@@ -199,6 +211,9 @@ def build(cls,
199211
Protos with packages outside this list are considered imports
200212
rather than explicit targets.
201213
opts (~.options.Options): CLI options passed to the generator.
214+
prior_protos (~.Proto): Previous, already processed protos.
215+
These are needed to look up messages in imported protos.
216+
Primarily used for testing.
202217
"""
203218
# Save information about the overall naming for this API.
204219
naming = api_naming.Naming.build(*filter(
@@ -221,16 +236,43 @@ def disambiguate_keyword_fname(
221236

222237
# Iterate over each FileDescriptorProto and fill out a Proto
223238
# object describing it, and save these to the instance.
224-
protos: Dict[str, Proto] = {}
239+
#
240+
# The first pass gathers messages and enums but NOT services or methods.
241+
# This is a workaround for a limitation in protobuf annotations for
242+
# long running operations: the annotations are strings that reference
243+
# message types but do not require a proto import.
244+
# This hack attempts to address a common case where API authors,
245+
# not wishing to generate an 'unused import' warning,
246+
# don't import the proto file defining the real response or metadata
247+
# type into the proto file that defines an LRO.
248+
# We just load all of the API's types first and then
249+
# load the services and methods with the full scope of types.
250+
pre_protos: Dict[str, Proto] = dict(prior_protos or {})
225251
for fd in file_descriptors:
226-
fd.name = disambiguate_keyword_fname(fd.name, protos)
227-
protos[fd.name] = _ProtoBuilder(
252+
fd.name = disambiguate_keyword_fname(fd.name, pre_protos)
253+
pre_protos[fd.name] = Proto.build(
228254
file_descriptor=fd,
229255
file_to_generate=fd.package.startswith(package),
230256
naming=naming,
231257
opts=opts,
232-
prior_protos=protos,
233-
).proto
258+
prior_protos=pre_protos,
259+
# Ugly, ugly hack.
260+
load_services=False,
261+
)
262+
263+
# Second pass uses all the messages and enums defined in the entire API.
264+
# This allows LRO returning methods to see all the types in the API,
265+
# bypassing the above missing import problem.
266+
protos: Dict[str, Proto] = {
267+
name: Proto.build(
268+
file_descriptor=proto.file_pb2,
269+
file_to_generate=proto.file_to_generate,
270+
naming=naming,
271+
opts=opts,
272+
prior_protos=pre_protos,
273+
)
274+
for name, proto in pre_protos.items()
275+
}
234276

235277
# Done; return the API.
236278
return cls(naming=naming, all_protos=protos)
@@ -319,11 +361,15 @@ class _ProtoBuilder:
319361
"""
320362
EMPTY = descriptor_pb2.SourceCodeInfo.Location()
321363

322-
def __init__(self, file_descriptor: descriptor_pb2.FileDescriptorProto,
323-
file_to_generate: bool,
324-
naming: api_naming.Naming,
325-
opts: options.Options = options.Options(),
326-
prior_protos: Mapping[str, Proto] = None):
364+
def __init__(
365+
self,
366+
file_descriptor: descriptor_pb2.FileDescriptorProto,
367+
file_to_generate: bool,
368+
naming: api_naming.Naming,
369+
opts: options.Options = options.Options(),
370+
prior_protos: Mapping[str, Proto] = None,
371+
load_services: bool = True
372+
):
327373
self.proto_messages: Dict[str, wrappers.MessageType] = {}
328374
self.proto_enums: Dict[str, wrappers.EnumType] = {}
329375
self.proto_services: Dict[str, wrappers.Service] = {}
@@ -388,7 +434,7 @@ def __init__(self, file_descriptor: descriptor_pb2.FileDescriptorProto,
388434
# This prevents us from generating common services (e.g. LRO) when
389435
# they are being used as an import just to get types declared in the
390436
# same files.
391-
if file_to_generate:
437+
if file_to_generate and load_services:
392438
self._load_children(file_descriptor.service, self._load_service,
393439
address=self.address, path=(6,))
394440
# TODO(lukesneeringer): oneofs are on path 7.
@@ -522,6 +568,116 @@ def _get_fields(self,
522568
# Done; return the answer.
523569
return answer
524570

571+
def _get_retry_and_timeout(
572+
self,
573+
service_address: metadata.Address,
574+
meth_pb: descriptor_pb2.MethodDescriptorProto
575+
) -> Tuple[Optional[wrappers.RetryInfo], Optional[float]]:
576+
"""Returns the retry and timeout configuration of a method if it exists.
577+
578+
Args:
579+
service_address (~.metadata.Address): An address object for the
580+
service, denoting the location of these methods.
581+
meth_pb (~.descriptor_pb2.MethodDescriptorProto): A
582+
protobuf method objects.
583+
584+
Returns:
585+
Tuple[Optional[~.wrappers.RetryInfo], Optional[float]]: The retry
586+
and timeout information for the method if it exists.
587+
"""
588+
589+
# If we got a gRPC service config, get the appropriate retry
590+
# and timeout information from it.
591+
retry = None
592+
timeout = None
593+
594+
# This object should be a dictionary that conforms to the
595+
# gRPC service config proto:
596+
# Repo: https://github.com/grpc/grpc-proto/
597+
# Filename: grpc/service_config/service_config.proto
598+
#
599+
# We only care about a small piece, so we are just leaving
600+
# it as a dictionary and parsing accordingly.
601+
if self.opts.retry:
602+
# The gRPC service config uses a repeated `name` field
603+
# with a particular format, which we match against.
604+
# This defines the expected selector for *this* method.
605+
selector = {
606+
'service': '{package}.{service_name}'.format(
607+
package='.'.join(service_address.package),
608+
service_name=service_address.name,
609+
),
610+
'method': meth_pb.name,
611+
}
612+
613+
# Find the method config that applies to us, if any.
614+
mc = next((c for c in self.opts.retry.get('methodConfig', [])
615+
if selector in c.get('name')), None)
616+
if mc:
617+
# Set the timeout according to this method config.
618+
if mc.get('timeout'):
619+
timeout = self._to_float(mc['timeout'])
620+
621+
# Set the retry according to this method config.
622+
if 'retryPolicy' in mc:
623+
r = mc['retryPolicy']
624+
retry = wrappers.RetryInfo(
625+
max_attempts=r.get('maxAttempts', 0),
626+
initial_backoff=self._to_float(
627+
r.get('initialBackoff', '0s'),
628+
),
629+
max_backoff=self._to_float(r.get('maxBackoff', '0s')),
630+
backoff_multiplier=r.get('backoffMultiplier', 0.0),
631+
retryable_exceptions=frozenset(
632+
exceptions.exception_class_for_grpc_status(
633+
getattr(grpc.StatusCode, code),
634+
)
635+
for code in r.get('retryableStatusCodes', [])
636+
),
637+
)
638+
639+
return retry, timeout
640+
641+
def _maybe_get_lro(
642+
self,
643+
service_address: metadata.Address,
644+
meth_pb: descriptor_pb2.MethodDescriptorProto
645+
) -> Optional[wrappers.OperationInfo]:
646+
"""Determines whether a method is a Long Running Operation (aka LRO)
647+
and, if it is, return an OperationInfo that includes the response
648+
and metadata types.
649+
650+
Args:
651+
service_address (~.metadata.Address): An address object for the
652+
service, denoting the location of these methods.
653+
meth_pb (~.descriptor_pb2.MethodDescriptorProto): A
654+
protobuf method objects.
655+
656+
Returns:
657+
Optional[~.wrappers.OperationInfo]: The info for the long-running
658+
operation, if the passed method is an LRO.
659+
"""
660+
lro = None
661+
662+
# If the output type is google.longrunning.Operation, we use
663+
# a specialized object in its place.
664+
if meth_pb.output_type.endswith('google.longrunning.Operation'):
665+
op = meth_pb.options.Extensions[operations_pb2.operation_info]
666+
if not op.response_type or not op.metadata_type:
667+
raise TypeError(
668+
f'rpc {meth_pb.name} returns a google.longrunning.'
669+
'Operation, but is missing a response type or '
670+
'metadata type.',
671+
)
672+
response_key = service_address.resolve(op.response_type)
673+
metadata_key = service_address.resolve(op.metadata_type)
674+
lro = wrappers.OperationInfo(
675+
response_type=self.api_messages[response_key],
676+
metadata_type=self.api_messages[metadata_key],
677+
)
678+
679+
return lro
680+
525681
def _get_methods(self,
526682
methods: Sequence[descriptor_pb2.MethodDescriptorProto],
527683
service_address: metadata.Address, path: Tuple[int, ...],
@@ -542,84 +698,16 @@ def _get_methods(self,
542698
"""
543699
# Iterate over the methods and collect them into a dictionary.
544700
answer: Dict[str, wrappers.Method] = collections.OrderedDict()
545-
for meth_pb, i in zip(methods, range(0, sys.maxsize)):
546-
lro = None
547-
548-
# If the output type is google.longrunning.Operation, we use
549-
# a specialized object in its place.
550-
if meth_pb.output_type.endswith('google.longrunning.Operation'):
551-
op = meth_pb.options.Extensions[operations_pb2.operation_info]
552-
if not op.response_type or not op.metadata_type:
553-
raise TypeError(
554-
f'rpc {meth_pb.name} returns a google.longrunning.'
555-
'Operation, but is missing a response type or '
556-
'metadata type.',
557-
)
558-
lro = wrappers.OperationInfo(
559-
response_type=self.api_messages[service_address.resolve(
560-
op.response_type,
561-
)],
562-
metadata_type=self.api_messages[service_address.resolve(
563-
op.metadata_type,
564-
)],
565-
)
566-
567-
# If we got a gRPC service config, get the appropriate retry
568-
# and timeout information from it.
569-
retry = None
570-
timeout = None
571-
572-
# This object should be a dictionary that conforms to the
573-
# gRPC service config proto:
574-
# Repo: https://github.com/grpc/grpc-proto/
575-
# Filename: grpc/service_config/service_config.proto
576-
#
577-
# We only care about a small piece, so we are just leaving
578-
# it as a dictionary and parsing accordingly.
579-
if self.opts.retry:
580-
# The gRPC service config uses a repeated `name` field
581-
# with a particular format, which we match against.
582-
# This defines the expected selector for *this* method.
583-
selector = {
584-
'service': '{package}.{service_name}'.format(
585-
package='.'.join(service_address.package),
586-
service_name=service_address.name,
587-
),
588-
'method': meth_pb.name,
589-
}
590-
591-
# Find the method config that applies to us, if any.
592-
mc = next((i for i in self.opts.retry.get('methodConfig', [])
593-
if selector in i.get('name')), None)
594-
if mc:
595-
# Set the timeout according to this method config.
596-
if mc.get('timeout'):
597-
timeout = self._to_float(mc['timeout'])
598-
599-
# Set the retry according to this method config.
600-
if 'retryPolicy' in mc:
601-
r = mc['retryPolicy']
602-
retry = wrappers.RetryInfo(
603-
max_attempts=r.get('maxAttempts', 0),
604-
initial_backoff=self._to_float(
605-
r.get('initialBackoff', '0s'),
606-
),
607-
max_backoff=self._to_float(
608-
r.get('maxBackoff', '0s'),
609-
),
610-
backoff_multiplier=r.get('backoffMultiplier', 0.0),
611-
retryable_exceptions=frozenset(
612-
exceptions.exception_class_for_grpc_status(
613-
getattr(grpc.StatusCode, code),
614-
)
615-
for code in r.get('retryableStatusCodes', [])
616-
),
617-
)
701+
for i, meth_pb in enumerate(methods):
702+
retry, timeout = self._get_retry_and_timeout(
703+
service_address,
704+
meth_pb
705+
)
618706

619707
# Create the method wrapper object.
620708
answer[meth_pb.name] = wrappers.Method(
621709
input=self.api_messages[meth_pb.input_type.lstrip('.')],
622-
lro=lro,
710+
lro=self._maybe_get_lro(service_address, meth_pb),
623711
method_pb=meth_pb,
624712
meta=metadata.Metadata(
625713
address=service_address.child(meth_pb.name, path + (i,)),

0 commit comments

Comments
 (0)