Skip to content

Commit

Permalink
Update pathways container specs for gke tpu job
Browse files (browse the repository at this point in the history)
  • Loading branch information
jesus-orozco committed Jan 15, 2025
1 parent a0bf9df commit 9f19159
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions axlearn/cloud/gcp/job.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -503,11 +503,7 @@ def __init__(self, cfg: Config):
def _is_pathways_used(self) -> bool:
# identify if a job is configured to use pathways by
# checking jax_backend flag and optional import for pathways utils
# brittle implementation
return (
"pathwaysutils" in self.config.import_modules
and "jax_backend proxy" in self.config.command
)
return "jax_backend=proxy" in self.config.command.replace(" ", "=")

def _import_modules(self):
try:
Expand Down Expand Up @@ -726,7 +722,7 @@ def _build_pathways_containers(self) -> list[dict]:
"""
cfg: TPUGKEJob.Config = self.config
system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type]
staging_location = f"{cfg.output_dir}/pathways-staging/tmp"
staging_location = f"{cfg.output_dir}/pathways-staging"
tpu_type = self._get_pathways_tpu_type(system.device_type)

return [
Expand Down Expand Up @@ -928,7 +924,7 @@ def _build_pod(self, job_type: str = None) -> Nested[Any]:

if job_type == "pathways-head":
# Target a specific CPU nodepool for Pathways containers
selector.update({"cloud.google.com/gke-nodepool": "pathways-head"})
selector.update({"node.kubernetes.io/instance-type": "n2-standard-32"})
initContainers.extend(self._build_pathways_containers())
else:
selector.update(
Expand Down

0 comments on commit 9f19159

Please sign in to comment.