Skip to content

Commit fc084fd

Browse files
committed
fix: type-def finished for datasets
1 parent e23e938 commit fc084fd

File tree

7 files changed

+23
-25
lines changed

7 files changed

+23
-25
lines changed

openml/base.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ def _entity_letter(cls) -> str:
4646
return cls.__name__.lower()[len("OpenML") :][0]
4747

4848
@abstractmethod
49-
def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
49+
def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str], None]]]:
5050
"""Collect all information to display in the __repr__ body.
5151
5252
Returns

openml/datasets/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -118,8 +118,8 @@ def __init__(
118118
description: str,
119119
data_format: str = "arff",
120120
cache_format: str = "pickle",
121-
dataset_id: Optional[int] = None,
122-
version: Optional[int] = None,
121+
dataset_id: Optional[str] = None,
122+
version: Optional[str] = None,
123123
creator: Optional[str] = None,
124124
contributor: Optional[str] = None,
125125
collection_date: Optional[str] = None,
@@ -129,7 +129,7 @@ def __init__(
129129
url: Optional[str] = None,
130130
default_target_attribute: Optional[str] = None,
131131
row_id_attribute: Optional[str] = None,
132-
ignore_attribute: Optional[str] = None,
132+
ignore_attribute: Optional[Union[List[str], str]] = None,
133133
version_label: Optional[str] = None,
134134
citation: Optional[str] = None,
135135
tag: Optional[str] = None,

openml/datasets/functions.py

Lines changed: 15 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -473,8 +473,6 @@ def get_dataset(
473473
dataset = _create_dataset_from_description(
474474
description, features_file, qualities_file, arff_file, parquet_file, cache_format
475475
)
476-
else:
477-
dataset = None
478476
return dataset
479477

480478

@@ -994,7 +992,7 @@ def _get_dataset_description(did_cache_dir: str, dataset_id: int) -> Dict[str, s
994992

995993

996994
def _get_dataset_parquet(
997-
description: Union[Dict[str, Union[str, int]], OpenMLDataset],
995+
description: Union[Dict[str, str], OpenMLDataset],
998996
cache_directory: Optional[str] = None,
999997
download_all_files: bool = False,
1000998
) -> Optional[str]:
@@ -1025,12 +1023,12 @@ def _get_dataset_parquet(
10251023
output_filename : string, optional
10261024
Location of the Parquet file if successfully downloaded, None otherwise.
10271025
"""
1028-
if isinstance(description, dict):
1029-
url = cast(str, description.get("oml:minio_url"))
1030-
did = description.get("oml:id")
1031-
elif isinstance(description, OpenMLDataset):
1026+
if isinstance(description, OpenMLDataset):
10321027
url = cast(str, description._minio_url)
10331028
did = description.dataset_id
1029+
elif isinstance(description, dict):
1030+
url = cast(str, description.get("oml:minio_url"))
1031+
did = int(description.get("oml:id", ""))
10341032
else:
10351033
raise TypeError("`description` should be either OpenMLDataset or Dict.")
10361034

@@ -1063,7 +1061,7 @@ def _get_dataset_parquet(
10631061

10641062

10651063
def _get_dataset_arff(
1066-
description: Union[Dict[str, Union[str, int]], OpenMLDataset],
1064+
description: Union[Dict[str, str], OpenMLDataset],
10671065
cache_directory: Optional[str] = None,
10681066
) -> str:
10691067
"""Return the path to the local arff file of the dataset. If is not cached, it is downloaded.
@@ -1088,14 +1086,14 @@ def _get_dataset_arff(
10881086
output_filename : string
10891087
Location of ARFF file.
10901088
"""
1091-
if isinstance(description, dict):
1092-
md5_checksum_fixture = description.get("oml:md5_checksum")
1093-
url = description["oml:url"]
1094-
did = description.get("oml:id")
1095-
elif isinstance(description, OpenMLDataset):
1089+
if isinstance(description, OpenMLDataset):
10961090
md5_checksum_fixture = description.md5_checksum
1097-
url = description.url
1091+
url = cast(str, description.url)
10981092
did = description.dataset_id
1093+
elif isinstance(description, dict):
1094+
md5_checksum_fixture = description.get("oml:md5_checksum")
1095+
url = cast(str, description["oml:url"])
1096+
did = int(description.get("oml:id", ""))
10991097
else:
11001098
raise TypeError("`description` should be either OpenMLDataset or Dict.")
11011099

@@ -1214,8 +1212,8 @@ def _create_dataset_from_description(
12141212
Dataset object from dict and ARFF.
12151213
"""
12161214
return OpenMLDataset(
1217-
description["oml:name"],
1218-
description.get("oml:description"),
1215+
name=description["oml:name"],
1216+
description=description.get("oml:description", ""),
12191217
data_format=description["oml:format"],
12201218
dataset_id=description["oml:id"],
12211219
version=description["oml:version"],
@@ -1246,7 +1244,7 @@ def _create_dataset_from_description(
12461244
)
12471245

12481246

1249-
def _get_online_dataset_arff(dataset_id: int) -> str:
1247+
def _get_online_dataset_arff(dataset_id: int) -> Optional[str]:
12501248
"""Download the ARFF file for a given dataset id
12511249
from the OpenML website.
12521250

openml/flows/flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def extension(self):
173173
"No extension could be found for flow {}: {}".format(self.flow_id, self.name)
174174
)
175175

176-
def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
176+
def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str], None]]]:
177177
"""Collect all information to display in the __repr__ body."""
178178
fields = {
179179
"Flow Name": self.name,

openml/runs/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def _evaluation_summary(self, metric: str) -> str:
189189

190190
return "{:.4f} +- {:.4f}".format(np.mean(rep_means), np.mean(rep_stds))
191191

192-
def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
192+
def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str], None]]]:
193193
"""Collect all information to display in the __repr__ body."""
194194
# Set up fields
195195
fields = {

openml/study/study.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def _entity_letter(cls) -> str:
9797
def id(self) -> Optional[int]:
9898
return self.study_id
9999

100-
def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
100+
def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str], None]]]:
101101
"""Collect all information to display in the __repr__ body."""
102102
fields: Dict[str, Any] = {
103103
"Name": self.name,

openml/tasks/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _entity_letter(cls) -> str:
8080
def id(self) -> Optional[int]:
8181
return self.task_id
8282

83-
def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
83+
def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str], None]]]:
8484
"""Collect all information to display in the __repr__ body."""
8585
fields: Dict[str, Any] = {
8686
"Task Type Description": "{}/tt/{}".format(

0 commit comments

Comments
 (0)