Skip to content
This repository has been archived by the owner on Mar 6, 2024. It is now read-only.

Commit

Permalink
Fix spacing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieJiangXXX committed Jan 31, 2024
1 parent ad9f93f commit a85a947
Showing 1 changed file with 37 additions and 22 deletions.
59 changes: 37 additions & 22 deletions nest_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class NestedAsyncIO:
"orig_tasks",
"orig_futures",
"orig_loop_attrs",
"policy_get_loop",
"orig_get_loops",
"orig_tc",
"patched"
Expand All @@ -35,6 +36,7 @@ def __init__(self, loop=None):
self.orig_tasks = []
self.orig_futures = []
self.orig_loop_attrs = {}
self.policy_get_loop = None
self.orig_get_loops = {}
self.orig_tc = None
self.patched = False
Expand Down Expand Up @@ -82,25 +84,32 @@ def run(main, *, debug=False):
with suppress(asyncio.CancelledError):
loop.run_until_complete(task)

def _get_event_loop(stacklevel=3):
return (events._get_running_loop() or
events.get_event_loop_policy().get_event_loop())

# Use module level _current_tasks, all_tasks and patch run method.
if getattr(asyncio, '_nest_patched', False):
return
if sys.version_info >= (3, 6, 0):
self.orig_tasks = [asyncio.Task, asyncio.tasks._CTask, asyncio.tasks.Task]
asyncio.Task = asyncio.tasks._CTask = asyncio.tasks.Task = \
asyncio.tasks._PyTask
self.orig_futures = [asyncio.Future, asyncio.futures._CFuture, asyncio.futures.Future]
asyncio.Future = asyncio.futures._CFuture = asyncio.futures.Future = \
asyncio.futures._PyFuture
self.orig_tasks = [asyncio.Task, asyncio.tasks._CTask,
asyncio.tasks.Task]
asyncio.Task = asyncio.tasks._CTask = \
asyncio.tasks.Task = asyncio.tasks._PyTask
self.orig_futures = [asyncio.Future, asyncio.futures._CFuture,
asyncio.futures.Future]
asyncio.Future = asyncio.futures._CFuture = \
asyncio.futures.Future = asyncio.futures._PyFuture
if sys.version_info < (3, 7, 0):
asyncio.tasks._current_tasks = asyncio.tasks.Task._current_tasks
asyncio.all_tasks = asyncio.tasks.Task.all_tasks
elif sys.version_info >= (3, 9, 0):
self.orig_get_loops["events__get_event_loop"] = events._get_event_loop
self.orig_get_loops["events_get_event_loop"] = events.get_event_loop
self.orig_get_loops["asyncio_get_event_loop"] = asyncio.get_event_loop
events._get_event_loop = events.get_event_loop = asyncio.get_event_loop = \
lambda stacklevel = 3: events._get_running_loop() or events.get_event_loop_policy().get_event_loop()
self.orig_get_loops = \
{"events__get_event_loop": events._get_event_loop,
"events_get_event_loop": events.get_event_loop,
"asyncio_get_event_loop": asyncio.get_event_loop}
events._get_event_loop = events.get_event_loop = \
asyncio.get_event_loop = _get_event_loop
self.orig_run = asyncio.run
asyncio.run = run
asyncio._nest_patched = True
Expand All @@ -110,15 +119,18 @@ def unpatch_asyncio(self):
asyncio.run = self.orig_run
asyncio._nest_patched = False
if sys.version_info >= (3, 6, 0):
asyncio.Task, asyncio.tasks._CTask, asyncio.tasks.Task = self.orig_tasks
asyncio.Future, asyncio.futures._CFuture, asyncio.futures.Future = self.orig_futures
(asyncio.Task, asyncio.tasks._CTask,
asyncio.tasks.Task) = self.orig_tasks
(asyncio.Future, asyncio.futures._CFuture,
asyncio.futures.Future) = self.orig_futures
if sys.version_info >= (3, 9, 0):
events._get_event_loop = self.orig_get_loops["events__get_event_loop"]
events.get_event_loop = self.orig_get_loops["events_get_event_loop"]
asyncio.get_event_loop = self.orig_get_loops["asyncio_get_event_loop"]
for key, value in self.orig_get_loops.items():
setattr(asyncio if key.startswith('asyncio')
else events, key.split('_')[-1], value)

def patch_policy(self):
"""Patch the policy to always return a patched loop."""

def get_event_loop(this):
if this._local._loop is None:
loop = this.new_event_loop()
Expand All @@ -127,12 +139,13 @@ def get_event_loop(this):
return this._local._loop

cls = events.get_event_loop_policy().__class__
self.orig_get_loops[f"{cls}.get_event_loop"] = cls.get_event_loop
self.policy_get_loop = cls.get_event_loop
cls.get_event_loop = get_event_loop

def unpatch_policy(self):
cls = events.get_event_loop_policy().__class__
if orig := self.orig_get_loops[f"{cls}.get_event_loop"]:
orig = self.policy_get_loop
if orig:
cls.get_event_loop = orig

def patch_loop(self, loop):
Expand Down Expand Up @@ -258,7 +271,8 @@ def _check_running(this):
self.orig_loop_attrs[cls] = {}
self.orig_loop_attrs[cls]["run_forever"] = cls.run_forever
cls.run_forever = run_forever
self.orig_loop_attrs[cls]["run_until_complete"] = cls.run_until_complete
self.orig_loop_attrs[cls]["run_until_complete"] = \
cls.run_until_complete
cls.run_until_complete = run_until_complete
self.orig_loop_attrs[cls]["_run_once"] = cls._run_once
cls._run_once = _run_once
Expand All @@ -267,11 +281,12 @@ def _check_running(this):
self.orig_loop_attrs[cls]["_check_runnung"] = cls._check_running
cls._check_runnung = _check_running # typo in Python 3.7 source
cls._num_runs_pending = 1 if loop.is_running() else 0
cls._is_proactorloop = (
os.name == 'nt' and issubclass(cls, asyncio.ProactorEventLoop))
cls._is_proactorloop = (os.name == 'nt' and
issubclass(cls, asyncio.ProactorEventLoop))
if sys.version_info < (3, 7, 0):
cls._set_coroutine_origin_tracking = cls._set_coroutine_wrapper
curr_tasks = asyncio.tasks._current_tasks if sys.version_info >= (3, 7, 0) else asyncio.Task._current_tasks
curr_tasks = asyncio.tasks._current_tasks \
if sys.version_info >= (3, 7, 0) else asyncio.Task._current_tasks
cls._nest_patched = True

def unpatch_loop(self, loop):
Expand Down

0 comments on commit a85a947

Please sign in to comment.