Skip to content

Commit

Permalink
Add version selection snippet for r1.7 (#2577)
Browse files Browse the repository at this point in the history
* Add version selection snippet for r1.7

* VSCode artifact ignore
  • Loading branch information
jysohn23 authored Oct 27, 2020
1 parent 3f8c5dd commit 7231272
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ torch_xla/csrc/aten_xla_type_default.cpp
# Below files are not deleted by "setup.py clean".

third_party/tensorflow/

# Visual Studio Code files
.vscode
.vs
54 changes: 53 additions & 1 deletion torch_xla/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,57 @@
import logging
import os
import re
import socket
import time

from .version import __version__


def _maybe_select_tpu_version():
# Setup correct TPU runtime version for Colab and Kaggle.

def _is_open(ip, port):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if s.connect_ex((ip, int(port))) == 0:
return True
return False

def _wait_for_open(version, timeout=100, interval=10, log=True):
tpu_addr = os.environ['TPU_NAME'].split('grpc://')[1]
deadline = time.time() + timeout

while not _is_open(*tpu_addr.split(':')):
if log:
logging.warning(
f'Waiting for TPU to be start up with version pytorch-{version}...')
if time.time() > deadline:
raise RuntimeError('Timed out waiting for TPU to start up')
time.sleep(interval)

if log:
logging.warning(
f'TPU has started up successfully with version pytorch-{version}')

try:
tpu_name = os.environ.get('TPU_NAME', '')
if not tpu_name.startswith('grpc://'):
# Not colab/kaggle
return

import cloud_tpu_client
client = cloud_tpu_client.Client(tpu_name)
client.configure_tpu_version(
f'pytorch-{__version__}', restart_type='ifNeeded')
# client.wait_for_healthy() API doesn't work as we dont have TPU API access
_wait_for_open(__version__)
except ImportError:
logging.warning((
'Not selecting corresponding TPU runtime since cloud_tpu_client is not '
'installed. Ignore if not running on Colab/Kaggle TPU.'))
except Exception:
# This path is hit, when we get throttled by the verison changer
# when we import torch_xla from xmp.spawn-ed processes.
_wait_for_open(__version__, log=False)


def _setup_grpc():
Expand Down Expand Up @@ -33,13 +85,13 @@ def _setup_xla_flags():


# These needs to be called before the _XLAC module is loaded.
_maybe_select_tpu_version()
_setup_grpc()
_setup_xla_flags()

import atexit
import torch
from ._patched_functions import _apply_patches
from .version import __version__
import _XLAC


Expand Down

0 comments on commit 7231272

Please sign in to comment.