Kaynağa Gözat

Add profiler to help optimize nodes

Most of the code for this patch was written by ChatGPT.
I am not a huge fan of LLMs, but I see them as a tool to help me
save dev time and deliver better results, faster, for my
beloved users.
Joseph Brandenburg 3 ay önce
ebeveyn
işleme
b4ad5bfeff
4 değiştirilmiş dosya ile 181 ekleme ve 13 silme
  1. 2 2
      base_definitions.py
  2. 152 0
      dev_helpers/profile_nodes.py
  3. 1 1
      ops_nodegroup.py
  4. 26 10
      readtree.py

+ 2 - 2
base_definitions.py

@@ -160,7 +160,7 @@ class MantisTree(NodeTree):
         #    - Non-hierarchy links should be ignored in the circle-check and so the links should be marked valid in such a circle
         #    - hierarchy-links should be marked invalid and prevent the tree from executing.
 
-    def execute_tree(self,context, error_popups = False):
+    def execute_tree(self,context, error_popups = False, profile=False):
         self.prevent_next_exec = False
         if not self.hash:
             return
@@ -171,7 +171,7 @@ class MantisTree(NodeTree):
         from . import readtree
         try:
             context.scene.render.use_lock_interface = True
-            readtree.execute_tree(self.parsed_tree, self, context, error_popups)
+            readtree.execute_tree(self.parsed_tree, self, context, error_popups, profile=profile)
         except RecursionError as e:
             prRed("Recursion error while parsing tree.")
         finally:

+ 152 - 0
dev_helpers/profile_nodes.py

@@ -0,0 +1,152 @@
+import time
+import json
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple
+
+
+
@dataclass
class PassRecord:
    """Timing samples collected for one named pass of one node."""
    # Name of the node method that was timed (e.g. "bPrepare").
    pass_name: str
    # One entry per timed call; NodeProfiler.record stores perf_counter deltas here.
    durations: List[float] = field(default_factory=list)

    def add_duration(self, duration: float) -> None:
        """Append a single timing sample."""
        self.durations.append(duration)

    def stats(self) -> Dict[str, float]:
        """Return count/avg/total/max/min of the samples (all zero when empty)."""
        if not self.durations:
            return {"count": 0, "avg": 0, "total": 0, "max": 0, "min": 0}
        count = len(self.durations)
        total = sum(self.durations)
        return {
            "count": count,
            "avg": total / count,
            "total": total,
            "max": max(self.durations),
            "min": min(self.durations),
        }


@dataclass
class NodeRecord:
    """All timed passes recorded for a single node."""
    # Class name of the profiled node.
    node_type: str
    # Identifier of the node; NodeProfiler.record takes it from ``node.signature``.
    node_id: Tuple[Optional[str], ...] = field(default_factory=tuple)
    # Pass name -> samples for that pass.
    passes: Dict[str, PassRecord] = field(default_factory=dict)

    def record_pass(self, pass_name: str, duration: float) -> None:
        """Add one sample to ``pass_name``, creating its PassRecord on first use."""
        if pass_name not in self.passes:
            self.passes[pass_name] = PassRecord(pass_name)
        self.passes[pass_name].add_duration(duration)

    def stats(self) -> Dict[str, Dict[str, float]]:
        """Return ``{pass_name: PassRecord.stats()}`` for every recorded pass."""
        return {pname: prec.stats() for pname, prec in self.passes.items()}


@dataclass
class ProfileSession:
    """A named collection of NodeRecords, keyed by node id."""
    name: str
    nodes: Dict[Tuple[Optional[str], ...], NodeRecord] = field(default_factory=dict)

    def record(self, node_id: Tuple[Optional[str], ...], node_type: str,
               pass_name: str, duration: float) -> None:
        """Store one timing sample for the node identified by ``node_id``.

        A NodeRecord is created the first time a given ``node_id`` is seen.
        """
        if node_id not in self.nodes:
            # Keyword arguments: NodeRecord's field order is (node_type, node_id),
            # the opposite of this method's parameter order.
            self.nodes[node_id] = NodeRecord(node_type=node_type, node_id=node_id)
        self.nodes[node_id].record_pass(pass_name, duration)

    def stats(self) -> Dict[Tuple[Optional[str], ...], Dict[str, Dict[str, float]]]:
        """Return ``{node_id: NodeRecord.stats()}`` for every recorded node."""
        return {nid: node.stats() for nid, node in self.nodes.items()}


class NodeProfiler:
    """Times individual node method calls and groups the samples into sessions."""

    def __init__(self):
        self.sessions: Dict[str, ProfileSession] = {}
        # Session that record() writes into; set by new_session().
        self._current: Optional[ProfileSession] = None

    def new_session(self, name: str) -> None:
        """Create (or replace) the session ``name`` and make it the current one."""
        self.sessions[name] = ProfileSession(name)
        self._current = self.sessions[name]

    def record(self, node, pass_name, *args, **kwargs):
        """Call ``getattr(node, pass_name)(*args, **kwargs)`` and time it.

        The elapsed ``time.perf_counter()`` interval is stored in the current
        session under ``node.signature``. The wrapped call's return value is
        passed through. If the call raises, the exception propagates and no
        sample is stored.

        Raises:
            RuntimeError: if new_session() has not been called yet.
        """
        if self._current is None:
            raise RuntimeError("NodeProfiler.record() called before new_session()")
        node_id = node.signature  # assumed hashable (used as a dict key)
        node_type = node.__class__.__name__
        func = getattr(node, pass_name)
        start = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - start
        # Argument order must match ProfileSession.record(node_id, node_type, ...),
        # otherwise every node of one class collapses into a single record.
        self._current.record(node_id, node_type, pass_name, duration)
        return result

    def get_stats(self, session_name: Optional[str] = None):
        """Return the stats of ``session_name``, or of the newest session if None.

        Raises:
            IndexError: if ``session_name`` is None and no session exists.
            KeyError: if ``session_name`` is not a known session.
        """
        if session_name is None:
            if not self.sessions:
                raise IndexError("no profiling sessions have been recorded")
            session_name = list(self.sessions)[-1]  # dicts keep insertion order
        return self.sessions[session_name].stats()

    def to_json(self) -> str:
        """Serialize every session to a JSON string readable by from_json()."""
        def encode(obj):
            if hasattr(obj, "__dict__"):
                return obj.__dict__
            raise TypeError(f"{type(obj).__name__} is not JSON serializable")

        payload = {}
        for sname, session in self.sessions.items():
            nodes = {}
            for node_id, node in session.nodes.items():
                # JSON object keys cannot be tuples, so the key is only a label;
                # the real id travels in the "node_id" field.
                nodes[str(node_id)] = {
                    "node_type": node.node_type,
                    "node_id": node.node_id,
                    "passes": {
                        pname: {"pass_name": prec.pass_name, "durations": prec.durations}
                        for pname, prec in node.passes.items()
                    },
                }
            payload[sname] = {"name": session.name, "nodes": nodes}
        return json.dumps(payload, default=encode, indent=2)

    @classmethod
    def from_json(cls, data: str):
        """Rebuild a NodeProfiler from a string produced by to_json().

        No session is made current; call new_session() before record().
        """
        raw = json.loads(data)
        profiler = cls()
        for sname, sdata in raw.items():
            session = ProfileSession(name=sdata["name"])
            for ndata in sdata["nodes"].values():
                raw_id = ndata["node_id"]
                # JSON turns tuples into lists; restore a hashable tuple id.
                node_id = tuple(raw_id) if isinstance(raw_id, list) else raw_id
                node = NodeRecord(node_type=ndata["node_type"], node_id=node_id)
                for pname, pdata in ndata["passes"].items():
                    node.passes[pname] = PassRecord(pname, list(pdata["durations"]))
                session.nodes[node_id] = node
            profiler.sessions[sname] = session
        return profiler
+
+
def summarize_profile(session, pass_name = "", sort_key='total', min_avg=0.0000033):
    """
    Print a formatted summary from a ProfileSession object.
    Columns: Node Type | Count | Total | Average | Min | Max

    Samples of ``pass_name`` are aggregated per node type over every node in
    ``session.nodes``. Nodes that never ran the pass are ignored.

    Args:
        session: object exposing ``nodes`` (a mapping whose values have
            ``node_type`` and ``passes``), e.g. a ProfileSession.
        pass_name: name of the pass to summarize.
        sort_key: column to sort by, descending; one of 'count', 'total',
            'avg', 'min', 'max'.
        min_avg: node types whose average duration is below this value are
            not printed as rows (their time still counts toward the footer
            total, but their calls are not counted in the footer count).

    Raises:
        KeyError: if ``sort_key`` is not a known column and there is at least
            one row to sort.
    """
    from collections import defaultdict

    summary = defaultdict(lambda: {"count": 0, "total": 0.0, "min": float("inf"), "max": 0.0})

    for node in session.nodes.values():
        prec = node.passes.get(pass_name)
        if not prec or not prec.durations: # not every node does every pass.
            continue
        s = summary[node.node_type]
        for d in prec.durations:
            s["count"] += 1
            s["total"] += d
            s["min"] = min(s["min"], d)
            s["max"] = max(s["max"], d)

    # Every entry has at least one sample, so the division is safe. Storing the
    # average makes 'avg' usable as a sort_key like the other printed columns.
    for data in summary.values():
        data["avg"] = data["total"] / data["count"]

    # Prepare and print the table
    print(f"Pass: {pass_name}\n")
    header =  f"{'Node Type':<40}{'Count':>8}{'Total(s)':>12}{'Avg(s)':>12}{'Min(s)':>12}{'Max(s)':>12}"
    print(header)
    print("-" * len(header))

    # Largest value first; sorted() is stable, so ties keep insertion order.
    sorted_items = sorted(summary.items(), key=lambda item: item[1][sort_key], reverse=True)

    accumulated_count = 0; accumulated_total = 0; overall_avg = 0
    overall_min = float("inf"); overall_max = -1

    for node_type, data in sorted_items:
        count = data["count"]; total = data["total"]; avg = data["avg"]
        accumulated_total += total # always accumulate this even for noop
        if avg < min_avg: continue # try to avoid printing it if it is a no-op or not significant
        accumulated_count += count
        overall_min = min(overall_min, data['min']); overall_max = max(overall_max, data['max'])
        print(f"{node_type:<40}{count:>8}{total:>12.4f}{avg:>12.4f}{data['min']:>12.4f}{data['max']:>12.4f}")

    if accumulated_count != 0: # avoid zero-division. The average is not meaningful in this case, anyway.
        overall_avg = accumulated_total/accumulated_count
    else:
        # No row was printed: show zeros instead of the inf / -1 sentinels.
        overall_min = 0.0; overall_max = 0.0

    footer =  f"{f'Summary({pass_name}): ':<40}{int(accumulated_count):>8}{accumulated_total:>12.4f}{overall_avg:>12.4f}{overall_min:>12.4f}{overall_max:>12.4f}"
    print("-" * len(footer))
    print(footer)
    print ("\n")

+ 1 - 1
ops_nodegroup.py

@@ -291,7 +291,7 @@ class ExecuteNodeTree(Operator):
             from pstats import SortKey
             with cProfile.Profile() as pr:
                 tree.update_tree(context, error_popups = pass_error)
-                tree.execute_tree(context, error_popups = pass_error)
+                tree.execute_tree(context, error_popups = pass_error, profile=True)
                 # from the Python docs at https://docs.python.org/3/library/profile.html#module-cProfile
                 s = io.StringIO()
                 sortby = SortKey.TIME

+ 26 - 10
readtree.py

@@ -1,8 +1,6 @@
 from .utilities import prRed, prGreen, prPurple, prWhite, prOrange, \
                         wrapRed, wrapGreen, wrapPurple, wrapWhite, wrapOrange
 
-    
-
 
 def grp_node_reroute_common(mantis_node, to_mantis_node, all_mantis_nodes):
     # we need to do this: go  to the to-node
@@ -536,7 +534,7 @@ def sort_execution(nodes, xForm_pass):
                         xForm_pass.appendleft(conn)
     return sorted_nodes, execution_failed
 
-def execute_tree(nodes, base_tree, context, error_popups = False):
+def execute_tree(nodes, base_tree, context, error_popups = False, profile=False):
     assert nodes is not None, "Failed to parse tree."
     assert len(nodes) > 0, "No parsed nodes for execution."\
                            " Mantis probably failed to parse the tree."
@@ -559,14 +557,22 @@ def execute_tree(nodes, base_tree, context, error_popups = False):
     switch_me = [] # switch the mode on these objects
     active = None # only need it for switching modes
     select_me = []
+    profiler = None
+    if profile:
+        from .dev_helpers.profile_nodes import NodeProfiler
+        profiler = NodeProfiler()
+        profiler.new_session(mContext.execution_id)
+        sort_key = 'total'
     try:
         sorted_nodes, execution_failed = sort_execution(nodes, xForm_pass)
         for n in sorted_nodes:
             try:
                 if not n.prepared:
-                    n.bPrepare(context)
+                    if profiler: profiler.record(n, "bPrepare", context)
+                    else: n.bPrepare(context)
                 if not n.executed:
-                    n.bTransformPass(context)
+                    if profiler: profiler.record(n, "bTransformPass", context)
+                    else: n.bTransformPass(context)
                 if (n.__class__.__name__ == "xFormArmature" ):
                     ob = n.bGetObject()
                     switch_me.append(ob)
@@ -585,10 +591,11 @@ def execute_tree(nodes, base_tree, context, error_popups = False):
 
         for n in sorted_nodes:
             try:
-                if not n.prepared:
-                    n.bPrepare(context)
+                if profiler: profiler.record(n, "bPrepare", context)
+                else: n.bPrepare(context)
                 if not n.executed:
-                    n.bRelationshipPass(context)
+                    if profiler: profiler.record(n, "bRelationshipPass", context)
+                    else: n.bRelationshipPass(context)
             except Exception as e:
                 e = execution_error_cleanup(n, e, show_error=error_popups)
                 if error_popups == False:
@@ -606,7 +613,8 @@ def execute_tree(nodes, base_tree, context, error_popups = False):
 
         for n in sorted_nodes:
             try:
-                n.bFinalize(context)
+                if profiler: profiler.record(n, "bFinalize", context)
+                else: n.bFinalize(context)
             except Exception as e:
                 e = execution_error_cleanup(n, e, show_error=error_popups)
                 if error_popups == False:
@@ -621,7 +629,8 @@ def execute_tree(nodes, base_tree, context, error_popups = False):
         # finally, apply modifiers and bind stuff
         for n in sorted_nodes:
             try:
-                n.bModifierApply(context)
+                if profiler: profiler.record(n, "bModifierApply", context)
+                else: n.bModifierApply(context)
             except Exception as e:
                 e = execution_error_cleanup(n, e, show_error=error_popups)
                 if error_popups == False:
@@ -634,6 +643,13 @@ def execute_tree(nodes, base_tree, context, error_popups = False):
         tot_time = (time() - start_execution_time)
         if not execution_failed:
             prGreen(f"Executed tree of {len(sorted_nodes)} nodes in {tot_time} seconds")
+        if profiler:
+            from .dev_helpers.profile_nodes import summarize_profile
+            summarize_profile(profiler._current, pass_name='bPrepare', sort_key = sort_key)
+            summarize_profile(profiler._current, pass_name='bTransformPass', sort_key = sort_key)
+            summarize_profile(profiler._current, pass_name='bRelationshipPass', sort_key = sort_key)
+            summarize_profile(profiler._current, pass_name='bFinalize', sort_key = sort_key)
+            summarize_profile(profiler._current, pass_name='bModifierApply', sort_key = sort_key)
         if (original_active):
             context.view_layer.objects.active = original_active
             original_active.select_set(True)