profile_nodes.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import time
  2. import json
  3. from dataclasses import dataclass, field
  4. from typing import Dict, List, Optional, Tuple
  5. @dataclass
  6. class PassRecord:
  7. pass_name: str
  8. durations: List[float] = field(default_factory=list)
  9. def add_duration(self, duration: float):
  10. self.durations.append(duration)
  11. def stats(self):
  12. if not self.durations:
  13. return {"count": 0, "avg": 0, "total": 0, "max": 0, "min": 0}
  14. return {
  15. "count": len(self.durations),
  16. "avg": sum(self.durations) / len(self.durations),
  17. "total": sum(self.durations),
  18. "max": max(self.durations),
  19. "min": min(self.durations),
  20. }
  21. @dataclass
  22. class NodeRecord:
  23. node_type: str
  24. node_id: Tuple[Optional[str]] = field(default_factory=tuple)
  25. passes: Dict[str, PassRecord] = field(default_factory=dict)
  26. def record_pass(self, pass_name: str, duration: float):
  27. if pass_name not in self.passes:
  28. self.passes[pass_name] = PassRecord(pass_name)
  29. self.passes[pass_name].add_duration(duration)
  30. def stats(self):
  31. return {pname: prec.stats() for pname, prec in self.passes.items()}
  32. @dataclass
  33. class ProfileSession:
  34. name: str
  35. nodes: Dict[Tuple[Optional[str]], NodeRecord] = field(default_factory=dict)
  36. def record(self, node_id: Tuple[Optional[str]], node_type: str, pass_name: str, duration: float):
  37. if node_id not in self.nodes:
  38. self.nodes[node_id] = NodeRecord(node_id, node_type)
  39. self.nodes[node_id].record_pass(pass_name, duration)
  40. def stats(self):
  41. return {nid: node.stats() for nid, node in self.nodes.items()}
  42. class NodeProfiler:
  43. def __init__(self):
  44. self.sessions: Dict[str, ProfileSession] = {}
  45. def new_session(self, name: str):
  46. self.sessions[name] = ProfileSession(name)
  47. self._current = self.sessions[name]
  48. def record(self, node, pass_name, *args, **kwargs):
  49. """Profile execution of a function (e.g., node pass)."""
  50. node_id = node.signature
  51. node_type = node.__class__.__name__
  52. func = getattr(node, pass_name)
  53. start = time.perf_counter()
  54. result = func(*args, **kwargs)
  55. duration = time.perf_counter() - start
  56. self._current.record(node_type, node_id, pass_name, duration)
  57. return result
  58. def get_stats(self, session_name: Optional[str]):
  59. if session_name is None:
  60. session_name = list(self.sessions.keys())[-1] # get the lastest one
  61. return self.sessions[session_name].stats()
  62. def to_json(self):
  63. def encode(obj):
  64. if hasattr(obj, "__dict__"):
  65. return obj.__dict__
  66. return obj
  67. return json.dumps(self.sessions, default=encode, indent=2)
  68. @classmethod
  69. def from_json(cls, data: str):
  70. raw = json.loads(data)
  71. profiler = cls()
  72. for sname, sdata in raw.items():
  73. session = ProfileSession(name=sdata["name"])
  74. for nid, ndata in sdata["nodes"].items():
  75. node = NodeRecord(node_id=ndata["node_id"], node_type=ndata["node_type"])
  76. for pname, pdata in ndata["passes"].items():
  77. record = PassRecord(pname, pdata["durations"])
  78. node.passes[pname] = record
  79. session.nodes[nid] = node
  80. profiler.sessions[sname] = session
  81. return profiler
  82. def summarize_profile(session, pass_name = "", sort_key='total'):
  83. """
  84. Print a formatted summary from a ProfileSession object.
  85. Columns: Node Type | Count | Total | Average | Min | Max
  86. """
  87. from ..base_definitions import FLOAT_EPSILON
  88. from collections import defaultdict
  89. summary = defaultdict(lambda: {"count": 0, "total": 0.0, "min": float("inf"), "max": 0.0})
  90. for node in session.nodes.values():
  91. node_type = node.node_type
  92. prec = node.passes.get(pass_name)
  93. if prec: # not every node does every pass.
  94. for d in prec.durations:
  95. s = summary[node_type]
  96. s["count"] += 1
  97. s["total"] += d
  98. s["min"] = min(s["min"], d)
  99. s["max"] = max(s["max"], d)
  100. # Prepare and print the table
  101. print(f"Pass: {pass_name}\n")
  102. header = f"{'Node Type':<40}{'Count':>8}{'Total(s)':>12}{'Avg(s)':>12}{'Min(s)':>12}{'Max(s)':>12}"
  103. print(header)
  104. print("-" * len(header))
  105. sorted_items = list(summary.items())
  106. sorted_items.sort(key = lambda a : -a[1][sort_key]) # for some reason this reverse sorts if I don't negate it?
  107. accumulated_count = 0; accumulated_total = 0; overall_avg = 0
  108. overall_min = float("inf"); overall_max = -1
  109. for node_type, data in sorted_items:
  110. count = data["count"]; total = data["total"]
  111. accumulated_total += total # always accumulate this even for noop
  112. avg = total / count if count else 0
  113. if avg < 0.0000033: continue # try to avoid printing it if it is a no-op or not significant
  114. accumulated_count += count
  115. overall_min = min(overall_min, data['min']); overall_max = max(overall_max, data['max'])
  116. print(f"{node_type:<40}{count:>8}{total:>12.4f}{avg:>12.4f}{data['min']:>12.4f}{data['max']:>12.4f}")
  117. if accumulated_count != 0: # avoid zero-division. The average is not meaningful in this case, anyway.
  118. overall_avg = accumulated_total/accumulated_count
  119. footer = f"{f'Summary({pass_name}): ':<40}{int(accumulated_count):>8}{accumulated_total:>12.4f}{overall_avg:>12.4f}{overall_min:>12.4f}{overall_max:>12.4f}"
  120. print("-" * len(footer))
  121. print(footer)
  122. print ("\n")