Skip to content

Commit 9322253

Browse files
committed
Add support for partial execution in backend
When a prompt is submitted, it may optionally include `partial_execution_targets`, a list of node IDs. If present, only the output nodes whose IDs appear in that list are added to the execution list, rather than all output nodes.
1 parent da9dab7 commit 9322253

File tree

5 files changed

+233
-19
lines changed

5 files changed

+233
-19
lines changed

execution.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import time
88
import traceback
99
from enum import Enum
10-
from typing import List, Literal, NamedTuple, Optional
10+
from typing import List, Literal, NamedTuple, Optional, Union
1111
import asyncio
1212

1313
import torch
@@ -891,7 +891,7 @@ def full_type_name(klass):
891891
return klass.__qualname__
892892
return module + '.' + klass.__qualname__
893893

894-
async def validate_prompt(prompt_id, prompt):
894+
async def validate_prompt(prompt_id, prompt, partial_execution_list: Union[list[str], None]):
895895
outputs = set()
896896
for x in prompt:
897897
if 'class_type' not in prompt[x]:
@@ -915,7 +915,8 @@ async def validate_prompt(prompt_id, prompt):
915915
return (False, error, [], {})
916916

917917
if hasattr(class_, 'OUTPUT_NODE') and class_.OUTPUT_NODE is True:
918-
outputs.add(x)
918+
if partial_execution_list is None or x in partial_execution_list:
919+
outputs.add(x)
919920

920921
if len(outputs) == 0:
921922
error = {

server.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,12 @@ async def post_prompt(request):
681681
if "prompt" in json_data:
682682
prompt = json_data["prompt"]
683683
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
684-
valid = await execution.validate_prompt(prompt_id, prompt)
684+
685+
partial_execution_targets = None
686+
if "partial_execution_targets" in json_data:
687+
partial_execution_targets = json_data["partial_execution_targets"]
688+
689+
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
685690
extra_data = {}
686691
if "extra_data" in json_data:
687692
extra_data = json_data["extra_data"]

tests/inference/test_async_nodes.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from pytest import fixture
99
from comfy_execution.graph_utils import GraphBuilder
10-
from tests.inference.test_execution import ComfyClient
10+
from tests.inference.test_execution import ComfyClient, run_warmup
1111

1212

1313
@pytest.mark.execution
@@ -24,6 +24,7 @@ def _server(self, args_pytest, request):
2424
'--listen', args_pytest["listen"],
2525
'--port', str(args_pytest["port"]),
2626
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
27+
'--cpu',
2728
]
2829
use_lru, lru_size = request.param
2930
if use_lru:
@@ -82,6 +83,9 @@ def test_basic_async_execution(self, client: ComfyClient, builder: GraphBuilder)
8283

8384
def test_multiple_async_parallel_execution(self, client: ComfyClient, builder: GraphBuilder):
8485
"""Test that multiple async nodes execute in parallel."""
86+
# Warmup execution to ensure server is fully initialized
87+
run_warmup(client)
88+
8589
g = builder
8690
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
8791

@@ -148,6 +152,9 @@ def test_async_validate_inputs(self, client: ComfyClient, builder: GraphBuilder)
148152

149153
def test_async_lazy_evaluation(self, client: ComfyClient, builder: GraphBuilder):
150154
"""Test async nodes with lazy evaluation."""
155+
# Warmup execution to ensure server is fully initialized
156+
run_warmup(client, prefix="warmup_lazy")
157+
151158
g = builder
152159
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
153160
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@@ -305,6 +312,9 @@ def test_async_with_execution_blocker(self, client: ComfyClient, builder: GraphB
305312

306313
def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder):
307314
"""Test that async nodes are properly cached."""
315+
# Warmup execution to ensure server is fully initialized
316+
run_warmup(client, prefix="warmup_cache")
317+
308318
g = builder
309319
image = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
310320
sleep_node = g.node("TestSleep", value=image.out(0), seconds=0.2)
@@ -324,6 +334,9 @@ def test_async_caching_behavior(self, client: ComfyClient, builder: GraphBuilder
324334

325335
def test_async_with_dynamic_prompts(self, client: ComfyClient, builder: GraphBuilder):
326336
"""Test async nodes within dynamically generated prompts."""
337+
# Warmup execution to ensure server is fully initialized
338+
run_warmup(client, prefix="warmup_dynamic")
339+
327340
g = builder
328341
image1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
329342
image2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)

0 commit comments

Comments (0)