diff --git a/lnprototest/runner.py b/lnprototest/runner.py index 2cda7a6..0ef5a40 100644 --- a/lnprototest/runner.py +++ b/lnprototest/runner.py @@ -7,17 +7,16 @@ import coincurve import functools -import pyln -from pyln.proto.message import Message - from abc import ABC, abstractmethod from typing import Dict, Optional, List, Union, Any, Callable +from pyln.proto.message import Message + from .bitfield import bitfield from .errors import SpecFileError from .structure import Sequence from .event import Event, MustNotMsg, ExpectMsg -from .utils import privkey_expand +from .utils import privkey_expand, ResolvableStr, ResolvableInt, resolve_args from .keyset import KeySet from .namespace import namespace @@ -78,6 +77,33 @@ def get_stash(self, event: Event, stashname: str, default: Any = None) -> Any: raise SpecFileError(event, "Unknown stash name {}".format(stashname)) return self.stash[stashname] + def recv_msg( + self, timeout: int = 1000, skip_filter: Optional[int] = None + ) -> Message: + """Listen on the connection for incoming message. + + If the {skip_filter} is specified, the message that + match the filters are skipped. + """ + raw_msg = self.connection.read_message() + msg = Message.read(namespace(), io.BytesIO(raw_msg)) + self.add_stash(msg.messagetype.name, msg) + return msg + + def send_msg( + self, msg_name: str, **kwargs: Union[ResolvableStr, ResolvableInt] + ) -> None: + """Send a message through the last connection""" + msgtype = namespace().get_msgtype(msg_name) + msg = Message(msgtype, **resolve_args(self, kwargs)) + missing = msg.missing_fields() + if missing: + raise SpecFileError(self, "Missing fields {}".format(missing)) + binmsg = io.BytesIO() + msg.write(binmsg) + self.connection.send_message(binmsg.getvalue()) + # FIXME: we should listen to possible connection here + class Runner(ABC): """Abstract base class for runners. @@ -189,7 +215,7 @@ def is_running(self) -> bool: pass @abstractmethod - def connect(self, event: Event, connprivkey: str) -> None: + def connect(self, event: Event, connprivkey: str) -> RunnerConn: pass def send_msg(self, msg: Message) -> None: diff --git a/lnprototest/utils/__init__.py b/lnprototest/utils/__init__.py index ea3025c..790d473 100644 --- a/lnprototest/utils/__init__.py +++ b/lnprototest/utils/__init__.py @@ -14,6 +14,12 @@ check_hex, privkey_for_index, merge_events_sequences, + Resolvable, + ResolvableBool, + ResolvableInt, + ResolvableStr, + resolve_arg, + resolve_args, ) from .bitcoin_utils import ( ScriptType, diff --git a/lnprototest/utils/utils.py b/lnprototest/utils/utils.py index 4b28bd3..3d1fc16 100644 --- a/lnprototest/utils/utils.py +++ b/lnprototest/utils/utils.py @@ -8,11 +8,17 @@ import logging import traceback -from typing import Union, Sequence, List +from typing import Union, Sequence, List, Dict, Callable, Any from enum import IntEnum from lnprototest.keyset import KeySet +# Type for arguments: either strings, or functions to call at runtime +ResolvableStr = Union[str, Callable[["RunnerConn", "Event", str], str]] +ResolvableInt = Union[int, Callable[["RunnerConn", "Event", str], int]] +ResolvableBool = Union[int, Callable[["RunnerConn", "Event", str], bool]] +Resolvable = Union[Any, Callable[["RunnerConn", "Event", str], Any]] + class Side(IntEnum): local = 0 @@ -106,3 +112,19 @@ def merge_events_sequences( """Merge the two list in the pre-post order""" pre.extend(post) return pre + + +def resolve_arg(fieldname: str, conn: "RunnerConn", arg: Resolvable) -> Any: + """If this is a string, return it, otherwise call it to get result""" + if callable(arg): + return arg(conn, fieldname) + else: + return arg + + +def resolve_args(conn: "RunnerConn", kwargs: Dict[str, Resolvable]) -> Dict[str, Any]: + """Take a dict of args, replace callables with their return values""" + ret: Dict[str, str] = {} + for field, str_or_func in kwargs.items(): + ret[field] = resolve_arg(field, conn, str_or_func) + return ret diff --git a/tests/test_v2_bolt1-01-init.py b/tests/test_v2_bolt1-01-init.py index 38c00e2..42646b1 100644 --- a/tests/test_v2_bolt1-01-init.py +++ b/tests/test_v2_bolt1-01-init.py @@ -11,10 +11,10 @@ def test_v2_init_is_first_msg(runner: Runner, namespaceoverride: Any) -> None: """ runner.start() - runner.connect(None, connprivkey="03") - init_msg = runner.recv_msg() + conn1 = runner.connect(None, connprivkey="03") + init_msg = conn1.recv_msg() assert ( init_msg.messagetype.number == 16 ), f"received not an init msg but: {init_msg.to_str()}" - + conn1.send_msg("init", globalfeatures="", features="") runner.stop()