Skip to content

Commit 7868e09

Browse files
shengbao-zheng and facebook-github-bot
authored and committed
update commsReplay for 1.0.3-chakra.0.0.4 schema
Summary: The 1.0.3-chakra.0.0.4 schema (PR #124035) logs <group_name, group_desc> as the new pg_name instead of the pg uid in the profiler. - group_name remains the unique identifier, e.g. "0", "1". - group_desc is the user-specified name, e.g. "fsdp". This diff updates commsReplay to support the new schema. Reviewed By: shengfukevin Differential Revision: D56288398 fbshipit-source-id: 3663e45507e098cedb407609eecc2e0cec6890a1
1 parent c8e3f2f commit 7868e09

File tree

2 files changed

+28
-21
lines changed

2 files changed

+28
-21
lines changed

train/comms/pt/commsTraceParser.py

Lines changed: 26 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -216,6 +216,9 @@ def _parseExecutionTrace(
216216
Convert the Execution Trace comms metadata to the common trace format for replay.
217217
218218
"""
219+
# Execution Trace PG_ID types availability
220+
ET_PG_NAME_TUPLE = True if in_trace.schema == "1.0.3-chakra.0.0.4" else False
221+
ET_BACKENDID = True if in_trace.schema != "1.0.3-chakra.0.0.4" else False
219222

220223
initOps = []
221224
newCommsTrace = []
@@ -233,23 +236,21 @@ def _parseExecutionTrace(
233236
break
234237

235238
for pg in pgObj:
236-
backendId = pg["uid"] if "uid" in pg else pg["backend_id"]
239+
if not pg["pg_name"].isdecimal():
240+
# TODO support local synchronization pg
241+
continue
242+
pgId = int(pg["pg_name"])
237243
ranks = pg["ranks"]
238-
if isinstance(ranks, list):
239-
pgId = int(pg["pg_name"])
240-
groupCnt = pg["group_count"]
241-
pgRanksMap[pgId] = (
242-
ranks
243-
if len(ranks) > 0
244-
else list(range(pg["group_size"]))
245-
# rank list is empty when all ranks are in a pg
246-
)
247-
elif isinstance(
248-
ranks, dict
249-
): # TODO for legacy traces: remove once all ET use the most recent pg
250-
pgId = pg["pg_id"]
251-
pgRanksMap[pgId] = [int(rank) for rank in ranks.keys()]
252-
backendIdToPgid[backendId] = pgId
244+
groupCnt = pg["group_count"]
245+
pgRanksMap[pgId] = (
246+
ranks
247+
if len(ranks) > 0
248+
else list(range(pg["group_size"]))
249+
# rank list is empty when all ranks are in a pg
250+
)
251+
if ET_BACKENDID:
252+
backendId = pg["uid"] if "uid" in pg else pg["backend_id"]
253+
backendIdToPgid[backendId] = pgId
253254
break # only one process_group init node per trace
254255

255256
# Parse comms nodes
@@ -269,12 +270,16 @@ def _parseExecutionTrace(
269270
1 - shift
270271
] # 2nd value of inputs is the req id of the collective
271272

272-
backendId = node.inputs[
273+
pgIdentifier = node.inputs[
273274
2 - shift
274-
] # 3rd value of inputs is the backend id of the collective
275-
if backendId in backendIdToPgid:
276-
# Assign pg_id info for PGs that were created.
277-
newComm.pgId = backendIdToPgid[backendId]
275+
] # 3rd value of inputs is the pg identifier of the collective
276+
# Assign pg_id info for PGs that were created.
277+
if ET_BACKENDID and pgIdentifier in backendIdToPgid:
278+
newComm.pgId = backendIdToPgid[pgIdentifier]
279+
newComm.groupRanks = pgRanksMap[newComm.pgId]
280+
newComm.worldSize = len(newComm.groupRanks)
281+
elif ET_PG_NAME_TUPLE and pgIdentifier[0].isdecimal():
282+
newComm.pgId = int(pgIdentifier[0])
278283
newComm.groupRanks = pgRanksMap[newComm.pgId]
279284
newComm.worldSize = len(newComm.groupRanks)
280285

train/compute/python/tools/execution_trace.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -296,6 +296,8 @@ def __init__(self, json):
296296
node_creation_func = {
297297
"1.0.1": ExecutionTrace._create_node_v1_0_1,
298298
"1.0.2-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
299+
# 1.0.3 expands pg name to <pg_name, pg_desc>, so it uses the same parser as 1.0.2
300+
"1.0.3-chakra.0.0.4": ExecutionTrace._create_node_v1_0_2_chakra_0_0_4,
299301
# Add future versions here
300302
}
301303
create_node = node_creation_func.get(self.schema, None)

0 commit comments

Comments
 (0)