Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 32 additions & 7 deletions impectPy/generate_xml.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,16 @@

allowed_codes = [
"playerName",
"squadName",
"team",
"actionType",
"action"
]

allowed_perspectives = [
"teamName",
"teamFocus"
]

# define allowed label/code combinations
combinations = {
"eventId": {"playerName": True, "team": False, "action": True, "actionType": True},
Expand Down Expand Up @@ -125,6 +130,8 @@ def generateXML(
p4Start: int,
p5Start: int,
codeTag: str,
squad: None,
perspective: None,
labels=None,
kpis=None,
labelSorting: bool = True,
Expand All @@ -151,11 +158,15 @@ def generateXML(
f"A valid integer is required, but got: '{start_time}'."
)

# handle kpis and labels defaults
# handle kpis, labels, squad and perspective defaults
if labels is None:
labels = [label["name"] for label in allowed_labels if combinations.get(label.get("name")).get(codeTag)]
if kpis is None:
kpis = [kpi["name"] for kpi in allowed_kpis]
if squad is None or perspective is None:
perspective = "teamName"
elif squad not in events["squadId"].unique():
raise ValueError(f"Provided squad ID '{squad}' not found in event data.")

# check for invalid kpis
invalid_kpis = [kpi for kpi in kpis if kpi not in [kpi["name"] for kpi in allowed_kpis]]
Expand All @@ -171,6 +182,9 @@ def generateXML(
if not codeTag in allowed_codes:
raise ValueError(f"Invalid Code: {codeTag}")

if not perspective in allowed_perspectives:
raise ValueError(f"Invalid perspective: {perspective}")

# keep only :
# - if KPI in kpis
# - if Label in labels
Expand Down Expand Up @@ -665,8 +679,19 @@ def generateXML(
# reset index
phases.reset_index(inplace=True)

# merge phase and squadName into one column to later pass into code tag
phases["teamPhase"] = phases["squadName"] + " - " + phases["phase"].str.replace("_", " ")
# Determine how to label team phases: by squadName or by role (home/away)

if perspective == "teamName":
phases["teamPhase"] = phases["squadName"] + " - " + phases["phase"].str.replace("_", " ")
elif perspective == "teamFocus":
my_squad_id = squad
phases["teamPhase"] = np.where(
phases["squadId"] == my_squad_id,
"mySquad - " + phases["phase"].str.replace("_", " "),
"opponent - " + phases["phase"].str.replace("_", " ")
)
else:
raise ValueError(f"Invalid value for perspective: {perspective}")

# get period starts

Expand Down Expand Up @@ -805,7 +830,7 @@ def get_bucket(bucket, value, zero_value, error_value):
seq_id_current = None

# If the selected code attribute is "squadName", generate XML entries from the `phases` DataFrame
if codeTag == "squadName":
if codeTag == "team":
for index, phase in phases.iterrows():
# Create a new XML instance for each team phase
instance = ET.SubElement(instances, "instance")
Expand Down Expand Up @@ -892,7 +917,7 @@ def get_bucket(bucket, value, zero_value, error_value):
seq_id_current = seq_id_new
else:
# Same logic as above, but without sequencing (i.e., one clip per row)
if codeTag == "squadName":
if codeTag == "team":
for index, phase in phases.iterrows():
instance = ET.SubElement(instances, "instance")
event_id = ET.SubElement(instance, "ID")
Expand Down Expand Up @@ -987,7 +1012,7 @@ def row(value, colors):
# call function
row(player, home_colors)

elif codeTag == "squadName":
elif codeTag == "team":
# add entries for away team phases
for phase in away_phases:
# call function
Expand Down