Skip to content

Commit 3703e7f

Browse files
authored
Correctly find the sample template when run via protoc (#198)
Add a sample_template parameter to sample generation Adjust tests to pass the parameter
1 parent 6762647 commit 3703e7f

File tree

7 files changed

+130
-64
lines changed

7 files changed

+130
-64
lines changed

packages/gapic-generator/gapic/generator/generator.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,18 @@ def get_response(self, api_schema: api.API) -> CodeGeneratorResponse:
9797
api_schema=api_schema,
9898
))
9999

100-
output_files.update(self._generate_samples_and_manifest(api_schema))
100+
output_files.update(self._generate_samples_and_manifest(
101+
api_schema,
102+
self._env.get_template(sample_templates[0]),
103+
))
101104

102105
# Return the CodeGeneratorResponse output.
103106
return CodeGeneratorResponse(file=[i for i in output_files.values()])
104107

105108
def _generate_samples_and_manifest(
106-
self,
107-
api_schema: api.API
109+
self,
110+
api_schema: api.API,
111+
sample_template: jinja2.Template,
108112
) -> Dict[str, CodeGeneratorResponse.File]:
109113
"""Generate samples and samplegen manifest for the API.
110114
@@ -152,7 +156,11 @@ def _generate_samples_and_manifest(
152156
str(spec).encode('utf8')).hexdigest()[:8]
153157
spec["id"] += f"_{spec_hash}"
154158

155-
sample = samplegen.generate_sample(spec, self._env, api_schema)
159+
sample = samplegen.generate_sample(
160+
spec,
161+
api_schema,
162+
sample_template,
163+
)
156164

157165
fpath = spec["id"] + ".py"
158166
fpath_to_spec_and_rendered[os.path.join(out_dir, fpath)] = (spec,

packages/gapic-generator/gapic/samplegen/manifest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,15 +84,15 @@ def transform_path(fpath):
8484
yaml.Collection(
8585
name="samples",
8686
elements=[
87-
[ # type: ignore
87+
[
8888
# Mypy doesn't correctly intuit the type of the
8989
# "region_tag" conditional expression.
9090
yaml.Alias(environment.anchor_name or ""),
9191
yaml.KeyVal("sample", sample["id"]),
9292
yaml.KeyVal(
9393
"path", transform_path(fpath)
9494
),
95-
(yaml.KeyVal("region_tag", sample["region_tag"])
95+
(yaml.KeyVal("region_tag", sample["region_tag"]) # type: ignore
9696
if "region_tag" in sample else
9797
yaml.Null),
9898
]

packages/gapic-generator/gapic/samplegen/samplegen.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import time
2222

2323
from gapic.samplegen_utils import types
24-
from gapic.schema import (api, wrappers)
24+
from gapic.schema import wrappers
2525

2626
from collections import (defaultdict, namedtuple, ChainMap as chainmap)
2727
from typing import (ChainMap, Dict, List, Mapping, Optional, Tuple)
@@ -139,8 +139,8 @@ def __init__(self, method: wrappers.Method):
139139
# and whether it's an enum or a message or a primitive type.
140140
# The method call response isn't a field, so construct an artificial
141141
# field that wraps the response.
142-
{ # type: ignore
143-
"$resp": MockField(response_type, False)
142+
{
143+
"$resp": MockField(response_type, False) # type: ignore
144144
}
145145
)
146146

@@ -481,8 +481,10 @@ def _validate_format(self, body: List[str]):
481481
num_prints = fmt_str.count("%s")
482482
if num_prints != len(body) - 1:
483483
raise types.MismatchedFormatSpecifier(
484-
"Expected {} expresssions in format string but received {}".format(
485-
num_prints, len(body) - 1
484+
"Expected {} expresssions in format string '{}' but found {}".format(
485+
num_prints,
486+
fmt_str,
487+
len(body) - 1
486488
)
487489
)
488490

@@ -502,7 +504,7 @@ def _validate_define(self, body: str):
502504
"""
503505
# Note: really checking for safety would be equivalent to
504506
# re-implementing the python interpreter.
505-
m = re.match(r"^([a-zA-Z]\w*)=([^=]+)$", body)
507+
m = re.match(r"^([a-zA-Z_]\w*) *= *([^=]+)$", body)
506508
if not m:
507509
raise types.BadAssignment(f"Bad assignment statement: {body}")
508510

@@ -654,27 +656,23 @@ def _validate_loop(self, loop):
654656
}
655657

656658

657-
def generate_sample(sample,
658-
env: jinja2.environment.Environment,
659-
api_schema: api.API,
660-
template_name: str = DEFAULT_TEMPLATE_NAME) -> str:
659+
def generate_sample(
660+
sample,
661+
api_schema,
662+
sample_template: jinja2.Template
663+
) -> str:
661664
"""Generate a standalone, runnable sample.
662665
663666
Rendering and writing the rendered output is left for the caller.
664667
665668
Args:
666669
sample (Any): A definition for a single sample generated from parsed yaml.
667-
env (jinja2.environment.Environment): The jinja environment used to generate
668-
the filled template for the sample.
669670
api_schema (api.API): The schema that defines the API to which the sample belongs.
670-
template_name (str): An optional override for the name of the template
671-
used to generate the sample.
671+
sample_template (jinja2.Template): The template representing a generic sample.
672672
673673
Returns:
674674
str: The rendered sample.
675675
"""
676-
sample_template = env.get_template(template_name)
677-
678676
service_name = sample["service"]
679677
service = api_schema.services.get(service_name)
680678
if not service:
@@ -701,7 +699,12 @@ def generate_sample(sample,
701699
return sample_template.render(
702700
file_header=FILE_HEADER,
703701
sample=sample,
704-
imports=[],
702+
imports=[
703+
"from google import auth",
704+
"from google.auth import credentials",
705+
],
705706
calling_form=calling_form,
706707
calling_form_enum=types.CallingForm,
708+
api=api_schema,
709+
service=service,
707710
)

packages/gapic-generator/gapic/templates/examples/feature_fragments.j2

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,32 +191,32 @@ client.{{ sample.rpc|snake_case }}({{ render_request_params(sample.request) }})
191191
{# it's just easier to set up client side streaming and other things from outside this macro. #}
192192
{% macro render_calling_form(method_invocation_text, calling_form, calling_form_enum, response_statements ) %}
193193
{% if calling_form == calling_form_enum.Request %}
194-
response = {{ method_invocation_text }}
194+
response = {{ method_invocation_text|trim }}
195195
{% for statement in response_statements %}
196196
{{ dispatch_statement(statement)|trim }}
197197
{% endfor %}
198198
{% elif calling_form == calling_form_enum.RequestPagedAll %}
199-
page_result = {{ method_invocation_text }}
199+
page_result = {{ method_invocation_text|trim }}
200200
for response in page_result:
201201
{% for statement in response_statements %}
202202
{{ dispatch_statement(statement)|trim }}
203203
{% endfor %}
204204
{% elif calling_form == calling_form_enum.RequestPaged %}
205-
page_result = {{ method_invocation_text}}
205+
page_result = {{ method_invocation_text|trim }}
206206
for page in page_result.pages():
207207
for response in page:
208208
{% for statement in response_statements %}
209209
{{ dispatch_statement(statement)|trim }}
210210
{% endfor %}
211211
{% elif calling_form in [calling_form_enum.RequestStreamingServer,
212212
calling_form_enum.RequestStreamingBidi] %}
213-
stream = {{ method_invocation_text }}
213+
stream = {{ method_invocation_text|trim }}
214214
for response in stream:
215215
{% for statement in response_statements %}
216216
{{ dispatch_statement(statement)|trim }}
217217
{% endfor %}
218218
{% elif calling_form == calling_form_enum.LongRunningRequestPromise %}
219-
operation = {{ method_invocation_text }}
219+
operation = {{ method_invocation_text|trim }}
220220

221221
print("Waiting for operation to complete...")
222222

@@ -237,8 +237,8 @@ def main():
237237

238238
parser = argparse.ArgumentParser()
239239
{% with arg_list = [] %}
240-
{% for request in request_block if request.body -%}
241-
{% for attr in request.body if attr.input_parameter %}
240+
{% for request in request_block if request.body -%}
241+
{% for attr in request.body if attr.input_parameter %}
242242
parser.add_argument("--{{ attr.input_parameter }}",
243243
type=str,
244244
default={{ attr.value }})

packages/gapic-generator/gapic/templates/examples/sample.py.j2

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,24 @@
2020
{# callingFormEnum #}
2121
{# Note: this sample template is WILDLY INACCURATE AND INCOMPLETE #}
2222
{# It does not correctly enums, unions, top level attributes, or various other things #}
23-
{% import "feature_fragments.j2" as frags %}
23+
{% import "examples/feature_fragments.j2" as frags %}
2424
{{ frags.sample_header(file_header, sample, calling_form) }}
2525

2626
# [START {{ sample.id }}]
2727
{# python code is responsible for all transformations: all we do here is render #}
2828
{% for import_statement in imports %}
2929
{{ import_statement }}
3030
{% endfor %}
31+
from {{ (api.naming.module_namespace + (api.naming.versioned_module_name,) + service.meta.address.subpackage)|join(".") }}.services.{{ service.name|snake_case }} import {{ service.name }}
3132

3233
{# also need calling form #}
3334
def sample_{{ frags.render_method_name(sample.rpc)|trim -}}({{ frags.print_input_params(sample.request)|trim -}}):
3435
"""{{ sample.description }}"""
3536

36-
client = {{ sample.service.split(".")[-3:-1]|
37-
map("lower")|
38-
join("_") }}.{{ sample.service.split(".")[-1] }}Client()
37+
client = {{ service.name }}(
38+
credentials=credentials.AnonymousCredentials(),
39+
transport="grpc",
40+
)
3941

4042
{{ frags.render_request_setup(sample.request)|indent }}
4143
{% with method_call = frags.render_method_call(sample, calling_form, calling_form_enum) %}

packages/gapic-generator/tests/unit/generator/test_generator.py

Lines changed: 43 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,15 @@ def test_custom_template_directory():
4242
def test_get_response():
4343
g = make_generator()
4444
with mock.patch.object(jinja2.FileSystemLoader, 'list_templates') as lt:
45-
lt.return_value = ['foo/bar/baz.py.j2']
45+
lt.return_value = ['foo/bar/baz.py.j2', 'molluscs/squid/sample.py.j2']
4646
with mock.patch.object(jinja2.Environment, 'get_template') as gt:
4747
gt.return_value = jinja2.Template('I am a template result.')
4848
cgr = g.get_response(api_schema=make_api())
4949
lt.assert_called_once()
50-
gt.assert_called_once()
50+
gt.assert_has_calls([
51+
mock.call('foo/bar/baz.py.j2'),
52+
mock.call('molluscs/squid/sample.py.j2')
53+
])
5154
assert len(cgr.file) == 1
5255
assert cgr.file[0].name == 'foo/bar/baz.py'
5356
assert cgr.file[0].content == 'I am a template result.\n'
@@ -56,24 +59,34 @@ def test_get_response():
5659
def test_get_response_ignores_empty_files():
5760
g = make_generator()
5861
with mock.patch.object(jinja2.FileSystemLoader, 'list_templates') as lt:
59-
lt.return_value = ['foo/bar/baz.py.j2']
62+
lt.return_value = ['foo/bar/baz.py.j2', 'molluscs/squid/sample.py.j2']
6063
with mock.patch.object(jinja2.Environment, 'get_template') as gt:
6164
gt.return_value = jinja2.Template('# Meaningless comment')
6265
cgr = g.get_response(api_schema=make_api())
6366
lt.assert_called_once()
64-
gt.assert_called_once()
67+
gt.assert_has_calls([
68+
mock.call('foo/bar/baz.py.j2'),
69+
mock.call('molluscs/squid/sample.py.j2')
70+
])
6571
assert len(cgr.file) == 0
6672

6773

6874
def test_get_response_ignores_private_files():
6975
g = make_generator()
7076
with mock.patch.object(jinja2.FileSystemLoader, 'list_templates') as lt:
71-
lt.return_value = ['foo/bar/baz.py.j2', 'foo/bar/_base.py.j2']
77+
lt.return_value = [
78+
'foo/bar/baz.py.j2',
79+
'foo/bar/_base.py.j2',
80+
'molluscs/squid/sample.py.j2',
81+
]
7282
with mock.patch.object(jinja2.Environment, 'get_template') as gt:
7383
gt.return_value = jinja2.Template('I am a template result.')
7484
cgr = g.get_response(api_schema=make_api())
7585
lt.assert_called_once()
76-
gt.assert_called_once()
86+
gt.assert_has_calls([
87+
mock.call('foo/bar/baz.py.j2'),
88+
mock.call('molluscs/squid/sample.py.j2')
89+
])
7790
assert len(cgr.file) == 1
7891
assert cgr.file[0].name == 'foo/bar/baz.py'
7992
assert cgr.file[0].content == 'I am a template result.\n'
@@ -82,7 +95,10 @@ def test_get_response_ignores_private_files():
8295
def test_get_response_fails_invalid_file_paths():
8396
g = make_generator()
8497
with mock.patch.object(jinja2.FileSystemLoader, 'list_templates') as lt:
85-
lt.return_value = ['foo/bar/$service/$proto/baz.py.j2']
98+
lt.return_value = [
99+
'foo/bar/$service/$proto/baz.py.j2',
100+
'molluscs/squid/sample.py.j2',
101+
]
86102
with pytest.raises(ValueError) as ex:
87103
g.get_response(api_schema=make_api())
88104

@@ -93,7 +109,10 @@ def test_get_response_fails_invalid_file_paths():
93109
def test_get_response_enumerates_services():
94110
g = make_generator()
95111
with mock.patch.object(jinja2.FileSystemLoader, 'list_templates') as lt:
96-
lt.return_value = ['foo/$service/baz.py.j2']
112+
lt.return_value = [
113+
'foo/$service/baz.py.j2',
114+
'molluscs/squid/sample.py.j2',
115+
]
97116
with mock.patch.object(jinja2.Environment, 'get_template') as gt:
98117
gt.return_value = jinja2.Template('Service: {{ service.name }}')
99118
cgr = g.get_response(api_schema=make_api(make_proto(
@@ -112,7 +131,10 @@ def test_get_response_enumerates_services():
112131
def test_get_response_enumerates_proto():
113132
g = make_generator()
114133
with mock.patch.object(jinja2.FileSystemLoader, 'list_templates') as lt:
115-
lt.return_value = ['foo/$proto.py.j2']
134+
lt.return_value = [
135+
'foo/$proto.py.j2',
136+
'molluscs/squid/sample.py.j2',
137+
]
116138
with mock.patch.object(jinja2.Environment, 'get_template') as gt:
117139
gt.return_value = jinja2.Template('Proto: {{ proto.module_name }}')
118140
cgr = g.get_response(api_schema=make_api(
@@ -146,6 +168,7 @@ def test_get_response_divides_subpackages():
146168
lt.return_value = [
147169
'foo/$sub/types/$proto.py.j2',
148170
'foo/$sub/services/$service.py.j2',
171+
'molluscs/squid/sample.py.j2',
149172
]
150173
with mock.patch.object(jinja2.Environment, 'get_template') as gt:
151174
gt.return_value = jinja2.Template("""
@@ -274,7 +297,11 @@ def test_parse_sample_paths(fs):
274297
@mock.patch(
275298
'time.gmtime',
276299
)
277-
def test_samplegen_config_to_output_files(mock_gmtime, mock_generate_sample, fs):
300+
def test_samplegen_config_to_output_files(
301+
mock_gmtime,
302+
mock_generate_sample,
303+
fs,
304+
):
278305
# These time values are nothing special,
279306
# they just need to be deterministic.
280307
returner = mock.MagicMock()
@@ -303,13 +330,14 @@ def test_samplegen_config_to_output_files(mock_gmtime, mock_generate_sample, fs)
303330
)
304331
)
305332

306-
mock_generate_sample
307-
308333
g = generator.Generator(
309334
options.Options.build(
310335
'samples=samples.yaml',
311336
)
312337
)
338+
# Need to have the sample template visible to the generator.
339+
g._env.loader = jinja2.DictLoader({'sample.py.j2': ''})
340+
313341
api_schema = make_api(naming=naming.Naming(name='Mollusc', version='v6'))
314342
actual_response = g.get_response(api_schema)
315343
expected_response = CodeGeneratorResponse(
@@ -393,6 +421,9 @@ def test_samplegen_id_disambiguation(mock_gmtime, mock_generate_sample, fs):
393421
)
394422
)
395423
g = generator.Generator(options.Options.build('samples=samples.yaml'))
424+
# Need to have the sample template visible to the generator.
425+
g._env.loader = jinja2.DictLoader({'sample.py.j2': ''})
426+
396427
api_schema = make_api(naming=naming.Naming(name='Mollusc', version='v6'))
397428
actual_response = g.get_response(api_schema)
398429
expected_response = CodeGeneratorResponse(

0 commit comments

Comments
 (0)