Skip to content

Commit 80a028a

Browse files
authored
Add mypy annotations for _api_calls.py (#1257)
* Add mypy annotations for _api_calls.py This commit adds mypy annotations for _api_calls.py to make sure that we have no untyped classes and functions. * Include Pieter's suggestions
1 parent 32c2902 commit 80a028a

File tree

5 files changed

+60
-25
lines changed

5 files changed

+60
-25
lines changed

.pre-commit-config.yaml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ repos:
1919
additional_dependencies:
2020
- types-requests
2121
- types-python-dateutil
22+
- id: mypy
23+
name: mypy top-level-functions
24+
files: openml/_api_calls.py
25+
additional_dependencies:
26+
- types-requests
27+
- types-python-dateutil
28+
args: [ --disallow-untyped-defs, --disallow-any-generics,
29+
--disallow-any-explicit, --implicit-optional ]
2230
- repo: https://github.com/pycqa/flake8
2331
rev: 6.0.0
2432
hooks:

openml/_api_calls.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import xml
1212
import xmltodict
1313
from urllib3 import ProxyManager
14-
from typing import Dict, Optional, Union
14+
from typing import Dict, Optional, Tuple, Union
1515
import zipfile
1616

1717
import minio
@@ -24,6 +24,9 @@
2424
OpenMLHashException,
2525
)
2626

27+
DATA_TYPE = Dict[str, Union[str, int]]
28+
FILE_ELEMENTS_TYPE = Dict[str, Union[str, Tuple[str, str]]]
29+
2730

2831
def resolve_env_proxies(url: str) -> Optional[str]:
2932
"""Attempt to find a suitable proxy for this url.
@@ -54,7 +57,12 @@ def _create_url_from_endpoint(endpoint: str) -> str:
5457
return url.replace("=", "%3d")
5558

5659

57-
def _perform_api_call(call, request_method, data=None, file_elements=None):
60+
def _perform_api_call(
61+
call: str,
62+
request_method: str,
63+
data: Optional[DATA_TYPE] = None,
64+
file_elements: Optional[FILE_ELEMENTS_TYPE] = None,
65+
) -> str:
5866
"""
5967
Perform an API call at the OpenML server.
6068
@@ -76,8 +84,6 @@ def _perform_api_call(call, request_method, data=None, file_elements=None):
7684
7785
Returns
7886
-------
79-
return_code : int
80-
HTTP return code
8187
return_value : str
8288
Return value of the OpenML server
8389
"""
@@ -257,7 +263,7 @@ def _download_text_file(
257263
return None
258264

259265

260-
def _file_id_to_url(file_id, filename=None):
266+
def _file_id_to_url(file_id: str, filename: Optional[str] = None) -> str:
261267
"""
262268
Presents the URL how to download a given file id
263269
filename is optional
@@ -269,7 +275,9 @@ def _file_id_to_url(file_id, filename=None):
269275
return url
270276

271277

272-
def _read_url_files(url, data=None, file_elements=None):
278+
def _read_url_files(
279+
url: str, data: Optional[DATA_TYPE] = None, file_elements: Optional[FILE_ELEMENTS_TYPE] = None
280+
) -> requests.Response:
273281
"""do a post request to url with data
274282
and sending file_elements as files"""
275283

@@ -288,7 +296,12 @@ def _read_url_files(url, data=None, file_elements=None):
288296
return response
289297

290298

291-
def __read_url(url, request_method, data=None, md5_checksum=None):
299+
def __read_url(
300+
url: str,
301+
request_method: str,
302+
data: Optional[DATA_TYPE] = None,
303+
md5_checksum: Optional[str] = None,
304+
) -> requests.Response:
292305
data = {} if data is None else data
293306
if config.apikey:
294307
data["api_key"] = config.apikey
@@ -306,10 +319,16 @@ def __is_checksum_equal(downloaded_file_binary: bytes, md5_checksum: Optional[st
306319
return md5_checksum == md5_checksum_download
307320

308321

309-
def _send_request(request_method, url, data, files=None, md5_checksum=None):
322+
def _send_request(
323+
request_method: str,
324+
url: str,
325+
data: DATA_TYPE,
326+
files: Optional[FILE_ELEMENTS_TYPE] = None,
327+
md5_checksum: Optional[str] = None,
328+
) -> requests.Response:
310329
n_retries = max(1, config.connection_n_retries)
311330

312-
response = None
331+
response: requests.Response
313332
with requests.Session() as session:
314333
# Start at one to have a non-zero multiplier for the sleep
315334
for retry_counter in range(1, n_retries + 1):
@@ -380,12 +399,12 @@ def human(n: int) -> float:
380399

381400
delay = {"human": human, "robot": robot}[config.retry_policy](retry_counter)
382401
time.sleep(delay)
383-
if response is None:
384-
raise ValueError("This should never happen!")
385402
return response
386403

387404

388-
def __check_response(response, url, file_elements):
405+
def __check_response(
406+
response: requests.Response, url: str, file_elements: Optional[FILE_ELEMENTS_TYPE]
407+
) -> None:
389408
if response.status_code != 200:
390409
raise __parse_server_exception(response, url, file_elements=file_elements)
391410
elif (
@@ -397,7 +416,7 @@ def __check_response(response, url, file_elements):
397416
def __parse_server_exception(
398417
response: requests.Response,
399418
url: str,
400-
file_elements: Dict,
419+
file_elements: Optional[FILE_ELEMENTS_TYPE],
401420
) -> OpenMLServerError:
402421
if response.status_code == 414:
403422
raise OpenMLServerError("URI too long! ({})".format(url))

openml/datasets/functions.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import os
66
from pyexpat import ExpatError
7-
from typing import List, Dict, Union, Optional, cast
7+
from typing import List, Dict, Optional, Union, cast
88
import warnings
99

1010
import numpy as np
@@ -867,7 +867,7 @@ def edit_dataset(
867867
raise TypeError("`data_id` must be of type `int`, not {}.".format(type(data_id)))
868868

869869
# compose data edit parameters as xml
870-
form_data = {"data_id": data_id}
870+
form_data = {"data_id": data_id} # type: openml._api_calls.DATA_TYPE
871871
xml = OrderedDict() # type: 'OrderedDict[str, OrderedDict]'
872872
xml["oml:data_edit_parameters"] = OrderedDict()
873873
xml["oml:data_edit_parameters"]["@xmlns:oml"] = "http://openml.org/openml"
@@ -888,7 +888,9 @@ def edit_dataset(
888888
if not xml["oml:data_edit_parameters"][k]:
889889
del xml["oml:data_edit_parameters"][k]
890890

891-
file_elements = {"edit_parameters": ("description.xml", xmltodict.unparse(xml))}
891+
file_elements = {
892+
"edit_parameters": ("description.xml", xmltodict.unparse(xml))
893+
} # type: openml._api_calls.FILE_ELEMENTS_TYPE
892894
result_xml = openml._api_calls._perform_api_call(
893895
"data/edit", "post", data=form_data, file_elements=file_elements
894896
)
@@ -929,7 +931,7 @@ def fork_dataset(data_id: int) -> int:
929931
if not isinstance(data_id, int):
930932
raise TypeError("`data_id` must be of type `int`, not {}.".format(type(data_id)))
931933
# compose data fork parameters
932-
form_data = {"data_id": data_id}
934+
form_data = {"data_id": data_id} # type: openml._api_calls.DATA_TYPE
933935
result_xml = openml._api_calls._perform_api_call("data/fork", "post", data=form_data)
934936
result = xmltodict.parse(result_xml)
935937
data_id = result["oml:data_fork"]["oml:id"]
@@ -949,7 +951,7 @@ def _topic_add_dataset(data_id: int, topic: str):
949951
"""
950952
if not isinstance(data_id, int):
951953
raise TypeError("`data_id` must be of type `int`, not {}.".format(type(data_id)))
952-
form_data = {"data_id": data_id, "topic": topic}
954+
form_data = {"data_id": data_id, "topic": topic} # type: openml._api_calls.DATA_TYPE
953955
result_xml = openml._api_calls._perform_api_call("data/topicadd", "post", data=form_data)
954956
result = xmltodict.parse(result_xml)
955957
data_id = result["oml:data_topic"]["oml:id"]
@@ -970,7 +972,7 @@ def _topic_delete_dataset(data_id: int, topic: str):
970972
"""
971973
if not isinstance(data_id, int):
972974
raise TypeError("`data_id` must be of type `int`, not {}.".format(type(data_id)))
973-
form_data = {"data_id": data_id, "topic": topic}
975+
form_data = {"data_id": data_id, "topic": topic} # type: openml._api_calls.DATA_TYPE
974976
result_xml = openml._api_calls._perform_api_call("data/topicdelete", "post", data=form_data)
975977
result = xmltodict.parse(result_xml)
976978
data_id = result["oml:data_topic"]["oml:id"]

openml/setups/functions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ def setup_exists(flow) -> int:
4949

5050
openml_param_settings = flow.extension.obtain_parameter_values(flow)
5151
description = xmltodict.unparse(_to_dict(flow.flow_id, openml_param_settings), pretty=True)
52-
file_elements = {"description": ("description.arff", description)}
52+
file_elements = {
53+
"description": ("description.arff", description)
54+
} # type: openml._api_calls.FILE_ELEMENTS_TYPE
5355
result = openml._api_calls._perform_api_call(
5456
"/setup/exists/", "post", file_elements=file_elements
5557
)

openml/study/functions.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def update_study_status(study_id: int, status: str) -> None:
277277
legal_status = {"active", "deactivated"}
278278
if status not in legal_status:
279279
raise ValueError("Illegal status value. " "Legal values: %s" % legal_status)
280-
data = {"study_id": study_id, "status": status}
280+
data = {"study_id": study_id, "status": status} # type: openml._api_calls.DATA_TYPE
281281
result_xml = openml._api_calls._perform_api_call("study/status/update", "post", data=data)
282282
result = xmltodict.parse(result_xml)
283283
server_study_id = result["oml:study_status_update"]["oml:id"]
@@ -357,8 +357,10 @@ def attach_to_study(study_id: int, run_ids: List[int]) -> int:
357357

358358
# Interestingly, there's no need to tell the server about the entity type, it knows by itself
359359
uri = "study/%d/attach" % study_id
360-
post_variables = {"ids": ",".join(str(x) for x in run_ids)}
361-
result_xml = openml._api_calls._perform_api_call(uri, "post", post_variables)
360+
post_variables = {"ids": ",".join(str(x) for x in run_ids)} # type: openml._api_calls.DATA_TYPE
361+
result_xml = openml._api_calls._perform_api_call(
362+
call=uri, request_method="post", data=post_variables
363+
)
362364
result = xmltodict.parse(result_xml)["oml:study_attach"]
363365
return int(result["oml:linked_entities"])
364366

@@ -400,8 +402,10 @@ def detach_from_study(study_id: int, run_ids: List[int]) -> int:
400402

401403
# Interestingly, there's no need to tell the server about the entity type, it knows by itself
402404
uri = "study/%d/detach" % study_id
403-
post_variables = {"ids": ",".join(str(x) for x in run_ids)}
404-
result_xml = openml._api_calls._perform_api_call(uri, "post", post_variables)
405+
post_variables = {"ids": ",".join(str(x) for x in run_ids)} # type: openml._api_calls.DATA_TYPE
406+
result_xml = openml._api_calls._perform_api_call(
407+
call=uri, request_method="post", data=post_variables
408+
)
405409
result = xmltodict.parse(result_xml)["oml:study_detach"]
406410
return int(result["oml:linked_entities"])
407411

0 commit comments

Comments
 (0)