File size: 9,432 Bytes
359fa44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
"""Test that progress updates are properly isolated between WebSocket clients."""

import json
import pytest
import time
import threading
import uuid
import websocket
from typing import List, Dict, Any
from comfy_execution.graph_utils import GraphBuilder
from tests.execution.test_execution import ComfyClient


class ProgressTracker:
    """Thread-safe log of the ``progress_state`` messages one client received.

    A listener thread appends messages while the test thread inspects them,
    so every read and write of ``progress_messages`` happens under ``lock``.
    """

    def __init__(self, client_id: str):
        # Identifier of the client whose messages are being recorded.
        self.client_id = client_id
        self.lock = threading.Lock()
        # Recorded messages, kept in arrival order.
        self.progress_messages: List[Dict[str, Any]] = []

    @staticmethod
    def _prompt_id_of(message: Dict[str, Any]) -> Any:
        """Return ``message['data']['prompt_id']``, or ``None`` if ``prompt_id`` is absent."""
        payload = message.get('data', {})
        return payload.get('prompt_id')

    def add_message(self, message: Dict[str, Any]):
        """Record *message*; safe to call concurrently with the query methods."""
        with self.lock:
            self.progress_messages += [message]

    def get_messages_for_prompt(self, prompt_id: str) -> List[Dict[str, Any]]:
        """Return a new list of the recorded messages whose ``data.prompt_id`` equals *prompt_id*."""
        with self.lock:
            matching = [
                entry for entry in self.progress_messages
                if self._prompt_id_of(entry) == prompt_id
            ]
        return matching

    def has_cross_contamination(self, own_prompt_id: str) -> bool:
        """Report whether any recorded message names a prompt other than *own_prompt_id*.

        Messages that carry no (or an empty) ``prompt_id`` are not counted.
        """
        with self.lock:
            return any(
                seen and seen != own_prompt_id
                for seen in map(self._prompt_id_of, self.progress_messages)
            )


class IsolatedClient(ComfyClient):
    """ComfyClient that records every JSON text message received on its WebSocket.

    ``progress_state`` messages are additionally routed to a per-client
    ``ProgressTracker`` so tests can check which prompts a client was told about.
    """

    def __init__(self):
        super().__init__()
        # Created in connect(), once the client_id is known.
        self.progress_tracker = None
        # Every decoded text message, in arrival order.
        self.all_messages: List[Dict[str, Any]] = []
        # Unexpected exception that ended the last listen_for_messages() call, or None.
        self.listen_error = None

    def connect(self, listen='127.0.0.1', port=8188, client_id=None):
        """Connect with a specific client_id and set up message tracking.

        A random UUID string is used as the client_id when none is given.
        """
        if client_id is None:
            client_id = str(uuid.uuid4())
        super().connect(listen, port, client_id)
        self.progress_tracker = ProgressTracker(client_id)

    def listen_for_messages(self, duration: float = 5.0):
        """Receive and record WebSocket messages for roughly *duration* seconds.

        Text messages are JSON-decoded and appended to ``all_messages``;
        ``progress_state`` messages are also handed to ``progress_tracker``.
        Non-text messages are ignored.

        Listening is best-effort: any error other than a receive timeout ends
        the loop early. The error is stored in ``listen_error`` rather than
        being silently discarded. The socket's previous timeout is restored on
        exit so later ``recv()`` calls on this socket are not left with the
        short polling timeout.
        """
        self.listen_error = None
        # Monotonic clock so wall-clock adjustments cannot stretch/shrink the window.
        end_time = time.monotonic() + duration
        previous_timeout = self.ws.gettimeout()
        # Short timeout so the loop re-checks the deadline at least every 0.5 s.
        self.ws.settimeout(0.5)
        try:
            while time.monotonic() < end_time:
                try:
                    out = self.ws.recv()
                    if isinstance(out, str):
                        message = json.loads(out)
                        self.all_messages.append(message)

                        # Track progress_state messages
                        if message.get('type') == 'progress_state':
                            self.progress_tracker.add_message(message)
                except websocket.WebSocketTimeoutException:
                    continue
                except Exception as e:
                    # Stop listening but keep what was collected; expose the cause.
                    self.listen_error = e
                    break
        finally:
            self.ws.settimeout(previous_timeout)


@pytest.mark.execution
class TestProgressIsolation:
    """Test suite for verifying progress update isolation between clients."""

    @pytest.fixture(scope="class", autouse=True)
    def _server(self, args_pytest):
        """Start a CPU-only ComfyUI server once for all tests in this class.

        On teardown the process is killed and then waited on, so it is reaped
        instead of lingering as a zombie.
        """
        import subprocess
        pargs = [
            'python', 'main.py',
            '--output-directory', args_pytest["output_dir"],
            '--listen', args_pytest["listen"],
            '--port', str(args_pytest["port"]),
            '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml',
            '--cpu',
        ]
        p = subprocess.Popen(pargs)
        yield
        p.kill()
        p.wait()

    @staticmethod
    def _close_client(client) -> None:
        """Close *client*'s WebSocket; tolerate a client that was never created or connected."""
        if client is not None and hasattr(client, 'ws'):
            client.ws.close()

    def start_client_with_retry(self, listen: str, port: int, client_id: str = None):
        """Connect a new IsolatedClient, retrying while the server is still starting.

        Every attempt (including the first) is preceded by a 4 s sleep, giving
        the server launched by the ``_server`` fixture time to accept connections.

        Raises:
            ConnectionRefusedError: if all attempts are refused; the last
                refusal is chained as the cause.
        """
        client = IsolatedClient()
        n_tries = 5
        last_error = None
        for i in range(n_tries):
            time.sleep(4)
            try:
                client.connect(listen, port, client_id)
                return client
            except ConnectionRefusedError as e:
                last_error = e
                print(e)  # noqa: T201
                print(f"({i+1}/{n_tries}) Retrying...")  # noqa: T201
        raise ConnectionRefusedError(f"Failed to connect after {n_tries} attempts") from last_error

    def test_progress_isolation_between_clients(self, args_pytest):
        """Test that progress updates are isolated between different clients."""
        listen = args_pytest["listen"]
        port = args_pytest["port"]

        # Create two separate clients with unique IDs
        client_a_id = "client_a_" + str(uuid.uuid4())
        client_b_id = "client_b_" + str(uuid.uuid4())

        # Bound before `try` so the cleanup in `finally` works even if connecting fails.
        client_a = None
        client_b = None
        try:
            # Connect both clients with retries
            client_a = self.start_client_with_retry(listen, port, client_a_id)
            client_b = self.start_client_with_retry(listen, port, client_b_id)

            # Create simple workflows for both clients
            graph_a = GraphBuilder(prefix="client_a")
            image_a = graph_a.node("StubImage", content="BLACK", height=256, width=256, batch_size=1)
            graph_a.node("PreviewImage", images=image_a.out(0))

            graph_b = GraphBuilder(prefix="client_b")
            image_b = graph_b.node("StubImage", content="WHITE", height=256, width=256, batch_size=1)
            graph_b.node("PreviewImage", images=image_b.out(0))

            # Submit workflows from both clients
            prompt_a = graph_a.finalize()
            prompt_b = graph_b.finalize()

            response_a = client_a.queue_prompt(prompt_a)
            prompt_id_a = response_a['prompt_id']

            response_b = client_b.queue_prompt(prompt_b)
            prompt_id_b = response_b['prompt_id']

            # Listen on both clients concurrently for the same window
            thread_a = threading.Thread(target=client_a.listen_for_messages, kwargs={'duration': 10.0})
            thread_b = threading.Thread(target=client_b.listen_for_messages, kwargs={'duration': 10.0})

            thread_a.start()
            thread_b.start()

            # Wait for threads to complete
            thread_a.join()
            thread_b.join()

            # Verify isolation
            # Client A should only receive progress for prompt_id_a
            assert not client_a.progress_tracker.has_cross_contamination(prompt_id_a), \
                "Client A received progress updates for other clients' workflows. " \
                f"Expected only {prompt_id_a}, but got messages for multiple prompts."

            # Client B should only receive progress for prompt_id_b
            assert not client_b.progress_tracker.has_cross_contamination(prompt_id_b), \
                "Client B received progress updates for other clients' workflows. " \
                f"Expected only {prompt_id_b}, but got messages for multiple prompts."

            # Verify each client received their own progress updates
            client_a_messages = client_a.progress_tracker.get_messages_for_prompt(prompt_id_a)
            client_b_messages = client_b.progress_tracker.get_messages_for_prompt(prompt_id_b)

            assert len(client_a_messages) > 0, \
                "Client A did not receive any progress updates for its own workflow"
            assert len(client_b_messages) > 0, \
                "Client B did not receive any progress updates for its own workflow"

            # Ensure no cross-contamination
            client_a_other = client_a.progress_tracker.get_messages_for_prompt(prompt_id_b)
            client_b_other = client_b.progress_tracker.get_messages_for_prompt(prompt_id_a)

            assert len(client_a_other) == 0, \
                f"Client A incorrectly received {len(client_a_other)} progress updates for Client B's workflow"
            assert len(client_b_other) == 0, \
                f"Client B incorrectly received {len(client_b_other)} progress updates for Client A's workflow"

        finally:
            # Clean up connections
            self._close_client(client_a)
            self._close_client(client_b)

    def test_progress_with_missing_client_id(self, args_pytest):
        """Test that progress updates handle missing client_id gracefully.

        NOTE(review): start_client_with_retry is called without a client_id, but
        IsolatedClient.connect then generates a random UUID, so the client does
        connect with an id. Confirm whether this test is meant to exercise a
        truly absent client_id.
        """
        listen = args_pytest["listen"]
        port = args_pytest["port"]

        # Bound before `try` so the cleanup in `finally` works even if connecting fails.
        client = None
        try:
            # Connect client with retries
            client = self.start_client_with_retry(listen, port)

            # Create a simple workflow
            graph = GraphBuilder(prefix="test_missing_id")
            image = graph.node("StubImage", content="BLACK", height=128, width=128, batch_size=1)
            graph.node("PreviewImage", images=image.out(0))

            # Submit workflow
            prompt = graph.finalize()
            response = client.queue_prompt(prompt)
            prompt_id = response['prompt_id']

            # Listen for messages
            client.listen_for_messages(duration=5.0)

            # Should still receive progress updates for own workflow
            messages = client.progress_tracker.get_messages_for_prompt(prompt_id)
            assert len(messages) > 0, \
                "Client did not receive progress updates even though it initiated the workflow"

        finally:
            self._close_client(client)