Skip to content

Commit b4b2504

Browse files
committed
pygit2: use update_tip callbacks instead of temp remote refspecs
1 parent a2c78a1 commit b4b2504

File tree

4 files changed

+78
-38
lines changed

4 files changed

+78
-38
lines changed

src/scmrepo/git/backend/base.py

+3
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,9 @@ def fetch_refspecs(
263263
returns True the local ref will be overwritten.
264264
Callback will be of the form:
265265
on_diverged(local_refname, remote_sha)
266+
267+
Returns:
268+
Mapping of local_refname to sync status.
266269
"""
267270

268271
@abstractmethod

src/scmrepo/git/backend/pygit2/__init__.py

+64-36
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030

3131
if TYPE_CHECKING:
32-
from pygit2 import Signature
32+
from pygit2 import Oid, Signature
3333
from pygit2.remote import Remote # type: ignore
3434
from pygit2.repository import Repository
3535

@@ -551,7 +551,8 @@ def _merge_remote_branch(
551551
raise SCMError("Unknown merge analysis result")
552552

553553
@contextmanager
554-
def get_remote(self, url: str) -> Generator["Remote", None, None]:
554+
def _get_remote(self, url: str) -> Generator["Remote", None, None]:
555+
"""Return a pygit2.Remote suitable for the specified Git URL or remote name."""
555556
try:
556557
remote = self.repo.remotes[url]
557558
url = remote.url
@@ -577,57 +578,84 @@ def fetch_refspecs(
577578
progress: Callable[["GitProgressEvent"], None] = None,
578579
**kwargs,
579580
) -> Mapping[str, SyncStatus]:
581+
import fnmatch
582+
580583
from pygit2 import GitError
581584

582585
from .callbacks import RemoteCallbacks
583586

584-
if isinstance(refspecs, str):
585-
refspecs = [refspecs]
587+
refspecs = self._refspecs_list(refspecs, force=force)
586588

587-
with self.get_remote(url) as remote:
588-
fetch_refspecs: List[str] = []
589-
for refspec in refspecs:
590-
if ":" in refspec:
591-
lh, rh = refspec.split(":")
592-
else:
593-
lh = rh = refspec
594-
if not rh.startswith("refs/"):
595-
rh = f"refs/heads/{rh}"
596-
if not lh.startswith("refs/"):
597-
lh = f"refs/heads/{lh}"
598-
rh = rh[len("refs/") :]
599-
refspec = f"+{lh}:refs/remotes/{remote.name}/{rh}"
600-
fetch_refspecs.append(refspec)
601-
602-
logger.debug("fetch_refspecs: %s", fetch_refspecs)
589+
# libgit2 rejects diverged refs but does not have a callback to notify
590+
# when a ref was rejected so we have to determine whether no callback
591+
# means up to date or rejected
592+
def _default_status(
593+
src: str, dst: str, remote_refs: Dict[str, "Oid"]
594+
) -> SyncStatus:
595+
try:
596+
if remote_refs[src] != self.repo.references[dst].target:
597+
return SyncStatus.DIVERGED
598+
except KeyError:
599+
# remote_refs lookup is skipped when force is set, refs cannot
600+
# be diverged on force
601+
pass
602+
return SyncStatus.UP_TO_DATE
603+
604+
with self._get_remote(url) as remote:
603605
with reraise(
604606
GitError,
605607
SCMError(f"Git failed to fetch ref from '{url}'"),
606608
):
607609
with RemoteCallbacks(progress=progress) as cb:
610+
remote_refs: Dict[str, "Oid"] = (
611+
{
612+
head["name"]: head["oid"]
613+
for head in remote.ls_remotes(callbacks=cb)
614+
}
615+
if not force
616+
else {}
617+
)
608618
remote.fetch(
609-
refspecs=fetch_refspecs,
619+
refspecs=refspecs,
610620
callbacks=cb,
621+
message="fetch",
611622
)
612623

613624
result: Dict[str, "SyncStatus"] = {}
614-
for refspec in fetch_refspecs:
615-
_, rh = refspec.split(":")
616-
if not rh.endswith("*"):
617-
refname = rh.split("/", 3)[-1]
618-
refname = f"refs/{refname}"
619-
result[refname] = self._merge_remote_branch(
620-
rh, refname, force, on_diverged
621-
)
622-
continue
623-
rh = rh.rstrip("*").rstrip("/") + "/"
624-
for branch in self.iter_refs(base=rh):
625-
refname = f"refs/{branch[len(rh):]}"
626-
result[refname] = self._merge_remote_branch(
627-
branch, refname, force, on_diverged
628-
)
625+
for refspec in refspecs:
626+
lh, rh = refspec.split(":")
627+
if lh.endswith("*"):
628+
assert rh.endswith("*")
629+
lh_prefix = lh[:-1]
630+
rh_prefix = rh[:-1]
631+
for refname in remote_refs:
632+
if fnmatch.fnmatch(refname, lh):
633+
src = refname
634+
dst = f"{rh_prefix}{refname[len(lh_prefix):]}"
635+
result[dst] = cb.result.get(
636+
src, _default_status(src, dst, remote_refs)
637+
)
638+
else:
639+
result[rh] = cb.result.get(lh, _default_status(lh, rh, remote_refs))
640+
629641
return result
630642

643+
@staticmethod
644+
def _refspecs_list(
645+
refspecs: Union[str, Iterable[str]],
646+
force: bool = False,
647+
) -> List[str]:
648+
if isinstance(refspecs, str):
649+
if force and not refspecs.startswith("+"):
650+
refspecs = f"+{refspecs}"
651+
return [refspecs]
652+
if force:
653+
return [
654+
(refspec if refspec.startswith("+") else f"+{refspec}")
655+
for refspec in refspecs
656+
]
657+
return list(refspecs)
658+
631659
def _stash_iter(self, ref: str):
632660
raise NotImplementedError
633661

src/scmrepo/git/backend/pygit2/callbacks.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from contextlib import AbstractContextManager
22
from types import TracebackType
3-
from typing import TYPE_CHECKING, Callable, Optional, Type, Union
3+
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
44

55
from pygit2 import RemoteCallbacks as _RemoteCallbacks
66

7+
from scmrepo.git.backend.base import SyncStatus
78
from scmrepo.git.credentials import Credential, CredentialNotFoundError
89
from scmrepo.progress import GitProgressReporter
910

1011
if TYPE_CHECKING:
12+
from pygit2 import Oid
1113
from pygit2.credentials import Keypair, Username, UserPass
1214

1315
from scmrepo.progress import GitProgressEvent
@@ -27,6 +29,7 @@ def __init__(
2729
self.progress = GitProgressReporter(progress) if progress else None
2830
self._store_credentials: Optional["Credential"] = None
2931
self._tried_credentials = False
32+
self.result: Dict[str, SyncStatus] = {}
3033

3134
def __exit__(
3235
self,
@@ -66,3 +69,9 @@ def credentials(
6669
def _approve_credentials(self):
6770
if self._store_credentials:
6871
self._store_credentials.approve()
72+
73+
def update_tips(self, refname: str, old: "Oid", new: "Oid"):
74+
if old == new:
75+
self.result[refname] = SyncStatus.UP_TO_DATE
76+
else:
77+
self.result[refname] = SyncStatus.SUCCESS

tests/test_pygit2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_pygit_stash_apply_conflicts(
7474
def test_pygit_ssh_error(tmp_dir: TmpDir, scm: Git, url):
7575
backend = Pygit2Backend(tmp_dir)
7676
with pytest.raises(NotImplementedError):
77-
with backend.get_remote(url):
77+
with backend._get_remote(url): # pylint: disable=protected-access
7878
pass
7979

8080

0 commit comments

Comments
 (0)