Skip to content

Commit 4d40f6f

Browse files
authored
Fix bug in MPITaskScheduler where tasks with identical priorities fail with TypeError (#3794)
# Description The `MPITaskScheduler` uses Python's PriorityQueue to prioritize tasks based on the number of nodes requested. When items with identical priorities are submitted to the PriorityQueue, the queue falls back to comparing the task dicts, which fails with a TypeError because dicts do not support ordering (`'<' not supported between instances of 'dict' and 'dict'`). This PR adds a new `PrioritizedTask` dataclass that sets the task element to `field(compare=False)`, so only the priority takes part in comparisons. I'm splitting changes in #3783 to keep the PR concise. This is split 1 of 3. # Changed Behaviour Fixes the bug described above. ## Type of change Choose which options apply, and delete the ones which do not apply. - Bug fix
1 parent 07a4efb commit 4d40f6f

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

parsl/executors/high_throughput/mpi_resource_management.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pickle
55
import queue
66
import subprocess
7+
from dataclasses import dataclass, field
78
from enum import Enum
89
from typing import Dict, List, Optional
910

@@ -69,6 +70,14 @@ def __str__(self):
6970
return f"MPINodesUnavailable(requested={self.requested} available={self.available})"
7071

7172

73+
@dataclass(order=True)
class PrioritizedTask:
    """Priority-queue entry that is ordered by ``priority`` alone.

    ``queue.PriorityQueue`` orders its entries with ``<``. If entries were
    plain ``(priority, task)`` tuples, two entries with the same priority
    would fall through to comparing the task dicts, and dicts do not define
    an ordering, so the comparison raises ``TypeError``. Excluding ``task``
    from the generated comparison methods avoids that.

    Note that entries with equal priority also compare equal, whatever
    their task payloads are.
    """

    # Sort key; the entry with the smallest value compares lowest.
    priority: int
    # Payload carried with the priority; compare=False keeps it out of
    # __eq__ and the ordering methods generated by @dataclass(order=True).
    task: Dict = field(compare=False)
79+
80+
7281
class TaskScheduler:
7382
"""Default TaskScheduler that does no taskscheduling
7483
@@ -111,7 +120,7 @@ def __init__(
111120
super().__init__(pending_task_q, pending_result_q)
112121
self.scheduler = identify_scheduler()
113122
# PriorityQueue is threadsafe
114-
self._backlog_queue: queue.PriorityQueue = queue.PriorityQueue()
123+
self._backlog_queue: queue.PriorityQueue[PrioritizedTask] = queue.PriorityQueue()
115124
self._map_tasks_to_nodes: Dict[str, List[str]] = {}
116125
self.available_nodes = get_nodes_in_batchjob(self.scheduler)
117126
self._free_node_counter = SpawnContext.Value("i", len(self.available_nodes))
@@ -169,7 +178,7 @@ def put_task(self, task_package: dict):
169178
allocated_nodes = self._get_nodes(nodes_needed)
170179
except MPINodesUnavailable:
171180
logger.info(f"Not enough resources, placing task {tid} into backlog")
172-
self._backlog_queue.put((nodes_needed, task_package))
181+
self._backlog_queue.put(PrioritizedTask(nodes_needed, task_package))
173182
return
174183
else:
175184
resource_spec["MPI_NODELIST"] = ",".join(allocated_nodes)
@@ -183,8 +192,8 @@ def put_task(self, task_package: dict):
183192
def _schedule_backlog_tasks(self):
184193
"""Attempt to schedule backlogged tasks"""
185194
try:
186-
_nodes_requested, task_package = self._backlog_queue.get(block=False)
187-
self.put_task(task_package)
195+
prioritized_task = self._backlog_queue.get(block=False)
196+
self.put_task(prioritized_task.task)
188197
except queue.Empty:
189198
return
190199
else:

parsl/tests/test_mpi_apps/test_mpi_scheduler.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,28 @@ def test_MPISched_contention():
161161
assert task_on_worker_side['task_id'] == 2
162162
_, _, _, resource_spec = unpack_res_spec_apply_message(task_on_worker_side['buffer'])
163163
assert len(resource_spec['MPI_NODELIST'].split(',')) == 8
164+
165+
166+
@pytest.mark.local
def test_hashable_backlog_queue():
    """Backlog several equal-priority tasks and check that none is lost.

    Every task requests all 8 nodes, so after the first one is placed the
    remaining two go into ``_backlog_queue`` with identical priorities.
    That forces ``queue.PriorityQueue`` to compare the queued entries, which
    must not raise ``TypeError``.
    """
    task_q, result_q = SpawnContext.Queue(), SpawnContext.Queue()
    scheduler = MPITaskScheduler(task_q, result_q)

    # The scheduler is expected to start with all 8 nodes free.
    assert scheduler.available_nodes
    assert len(scheduler.available_nodes) == 8

    assert scheduler._free_node_counter.value == 8

    for i in range(3):
        mock_task_buffer = pack_res_spec_apply_message("func", "args", "kwargs",
                                                       resource_specification={
                                                           "num_nodes": 8,
                                                           "ranks_per_node": 2
                                                       })
        task_package = {"task_id": i, "buffer": mock_task_buffer}
        scheduler.put_task(task_package)
    assert scheduler._backlog_queue.qsize() == 2, "Expected 2 backlogged tasks"

0 commit comments

Comments
 (0)