|
|
@@ -0,0 +1,152 @@
|
|
|
+import time
|
|
|
+import json
|
|
|
+from dataclasses import dataclass, field
|
|
|
+from typing import Dict, List, Optional, Tuple
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class PassRecord:
|
|
|
+ pass_name: str
|
|
|
+ durations: List[float] = field(default_factory=list)
|
|
|
+
|
|
|
+ def add_duration(self, duration: float):
|
|
|
+ self.durations.append(duration)
|
|
|
+
|
|
|
+ def stats(self):
|
|
|
+ if not self.durations:
|
|
|
+ return {"count": 0, "avg": 0, "total": 0, "max": 0, "min": 0}
|
|
|
+ return {
|
|
|
+ "count": len(self.durations),
|
|
|
+ "avg": sum(self.durations) / len(self.durations),
|
|
|
+ "total": sum(self.durations),
|
|
|
+ "max": max(self.durations),
|
|
|
+ "min": min(self.durations),
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class NodeRecord:
|
|
|
+ node_type: str
|
|
|
+ node_id: Tuple[Optional[str]] = field(default_factory=tuple)
|
|
|
+ passes: Dict[str, PassRecord] = field(default_factory=dict)
|
|
|
+
|
|
|
+ def record_pass(self, pass_name: str, duration: float):
|
|
|
+ if pass_name not in self.passes:
|
|
|
+ self.passes[pass_name] = PassRecord(pass_name)
|
|
|
+ self.passes[pass_name].add_duration(duration)
|
|
|
+
|
|
|
+ def stats(self):
|
|
|
+ return {pname: prec.stats() for pname, prec in self.passes.items()}
|
|
|
+
|
|
|
+
|
|
|
+@dataclass
|
|
|
+class ProfileSession:
|
|
|
+ name: str
|
|
|
+ nodes: Dict[Tuple[Optional[str]], NodeRecord] = field(default_factory=dict)
|
|
|
+
|
|
|
+ def record(self, node_id: Tuple[Optional[str]], node_type: str, pass_name: str, duration: float):
|
|
|
+ if node_id not in self.nodes:
|
|
|
+ self.nodes[node_id] = NodeRecord(node_id, node_type)
|
|
|
+ self.nodes[node_id].record_pass(pass_name, duration)
|
|
|
+
|
|
|
+ def stats(self):
|
|
|
+ return {nid: node.stats() for nid, node in self.nodes.items()}
|
|
|
+
|
|
|
+
|
|
|
+class NodeProfiler:
|
|
|
+ def __init__(self):
|
|
|
+ self.sessions: Dict[str, ProfileSession] = {}
|
|
|
+
|
|
|
+ def new_session(self, name: str):
|
|
|
+ self.sessions[name] = ProfileSession(name)
|
|
|
+ self._current = self.sessions[name]
|
|
|
+
|
|
|
+ def record(self, node, pass_name, *args, **kwargs):
|
|
|
+ """Profile execution of a function (e.g., node pass)."""
|
|
|
+ node_id = node.signature
|
|
|
+ node_type = node.__class__.__name__
|
|
|
+ func = getattr(node, pass_name)
|
|
|
+ start = time.perf_counter()
|
|
|
+ result = func(*args, **kwargs)
|
|
|
+ duration = time.perf_counter() - start
|
|
|
+ self._current.record(node_type, node_id, pass_name, duration)
|
|
|
+ return result
|
|
|
+
|
|
|
+ def get_stats(self, session_name: Optional[str]):
|
|
|
+ if session_name is None:
|
|
|
+ session_name = list(self.sessions.keys())[-1] # get the lastest one
|
|
|
+ return self.sessions[session_name].stats()
|
|
|
+
|
|
|
+ def to_json(self):
|
|
|
+ def encode(obj):
|
|
|
+ if hasattr(obj, "__dict__"):
|
|
|
+ return obj.__dict__
|
|
|
+ return obj
|
|
|
+ return json.dumps(self.sessions, default=encode, indent=2)
|
|
|
+
|
|
|
+ @classmethod
|
|
|
+ def from_json(cls, data: str):
|
|
|
+ raw = json.loads(data)
|
|
|
+ profiler = cls()
|
|
|
+ for sname, sdata in raw.items():
|
|
|
+ session = ProfileSession(name=sdata["name"])
|
|
|
+ for nid, ndata in sdata["nodes"].items():
|
|
|
+ node = NodeRecord(node_id=ndata["node_id"], node_type=ndata["node_type"])
|
|
|
+ for pname, pdata in ndata["passes"].items():
|
|
|
+ record = PassRecord(pname, pdata["durations"])
|
|
|
+ node.passes[pname] = record
|
|
|
+ session.nodes[nid] = node
|
|
|
+ profiler.sessions[sname] = session
|
|
|
+ return profiler
|
|
|
+
|
|
|
+
|
|
|
+def summarize_profile(session, pass_name = "", sort_key='total'):
|
|
|
+ """
|
|
|
+ Print a formatted summary from a ProfileSession object.
|
|
|
+ Columns: Node Type | Count | Total | Average | Min | Max
|
|
|
+ """
|
|
|
+ from ..base_definitions import FLOAT_EPSILON
|
|
|
+ from collections import defaultdict
|
|
|
+
|
|
|
+ summary = defaultdict(lambda: {"count": 0, "total": 0.0, "min": float("inf"), "max": 0.0})
|
|
|
+
|
|
|
+ for node in session.nodes.values():
|
|
|
+ node_type = node.node_type
|
|
|
+ prec = node.passes.get(pass_name)
|
|
|
+ if prec: # not every node does every pass.
|
|
|
+ for d in prec.durations:
|
|
|
+ s = summary[node_type]
|
|
|
+ s["count"] += 1
|
|
|
+ s["total"] += d
|
|
|
+ s["min"] = min(s["min"], d)
|
|
|
+ s["max"] = max(s["max"], d)
|
|
|
+
|
|
|
+ # Prepare and print the table
|
|
|
+ print(f"Pass: {pass_name}\n")
|
|
|
+ header = f"{'Node Type':<40}{'Count':>8}{'Total(s)':>12}{'Avg(s)':>12}{'Min(s)':>12}{'Max(s)':>12}"
|
|
|
+ print(header)
|
|
|
+ print("-" * len(header))
|
|
|
+
|
|
|
+ sorted_items = list(summary.items())
|
|
|
+ sorted_items.sort(key = lambda a : -a[1][sort_key]) # for some reason this reverse sorts if I don't negate it?
|
|
|
+
|
|
|
+ accumulated_count = 0; accumulated_total = 0; overall_avg = 0
|
|
|
+ overall_min = float("inf"); overall_max = -1
|
|
|
+
|
|
|
+ for node_type, data in sorted_items:
|
|
|
+ count = data["count"]; total = data["total"]
|
|
|
+ accumulated_total += total # always accumulate this even for noop
|
|
|
+ avg = total / count if count else 0
|
|
|
+ if avg < 0.0000033: continue # try to avoid printing it if it is a no-op or not significant
|
|
|
+ accumulated_count += count
|
|
|
+ overall_min = min(overall_min, data['min']); overall_max = max(overall_max, data['max'])
|
|
|
+ print(f"{node_type:<40}{count:>8}{total:>12.4f}{avg:>12.4f}{data['min']:>12.4f}{data['max']:>12.4f}")
|
|
|
+
|
|
|
+ if accumulated_count != 0: # avoid zero-division. The average is not meaningful in this case, anyway.
|
|
|
+ overall_avg = accumulated_total/accumulated_count
|
|
|
+
|
|
|
+ footer = f"{f'Summary({pass_name}): ':<40}{int(accumulated_count):>8}{accumulated_total:>12.4f}{overall_avg:>12.4f}{overall_min:>12.4f}{overall_max:>12.4f}"
|
|
|
+ print("-" * len(footer))
|
|
|
+ print(footer)
|
|
|
+ print ("\n")
|