From 738689ffa22539cd113c1ebf8fb3e691bc141137 Mon Sep 17 00:00:00 2001 From: Vincenzo Palazzo Date: Sat, 22 Jul 2023 19:46:19 +0200 Subject: [PATCH] conn: Add the possibility to send and receive message through the connection This commit enable the connection to send a message through the wire in an ergonomic way. This feature is a basic blocks for the lnprototest refactoring that allow to semplify how to write test with lnprototest in the future by keeping the state with the peer by connection and keep inside the runner just the necessary logic to interact with the node. Signed-off-by: Vincenzo Palazzo --- lnprototest/runner.py | 36 +++++++++++++++++++++++++++++----- lnprototest/utils/__init__.py | 6 ++++++ lnprototest/utils/utils.py | 24 ++++++++++++++++++++++- tests/test_v2_bolt1-01-init.py | 6 +++--- 4 files changed, 63 insertions(+), 9 deletions(-) diff --git a/lnprototest/runner.py b/lnprototest/runner.py index 2cda7a6..0ef5a40 100644 --- a/lnprototest/runner.py +++ b/lnprototest/runner.py @@ -7,17 +7,16 @@ import coincurve import functools -import pyln -from pyln.proto.message import Message - from abc import ABC, abstractmethod from typing import Dict, Optional, List, Union, Any, Callable +from pyln.proto.message import Message + from .bitfield import bitfield from .errors import SpecFileError from .structure import Sequence from .event import Event, MustNotMsg, ExpectMsg -from .utils import privkey_expand +from .utils import privkey_expand, ResolvableStr, ResolvableInt, resolve_args from .keyset import KeySet from .namespace import namespace @@ -78,6 +77,33 @@ def get_stash(self, event: Event, stashname: str, default: Any = None) -> Any: raise SpecFileError(event, "Unknown stash name {}".format(stashname)) return self.stash[stashname] + def recv_msg( + self, timeout: int = 1000, skip_filter: Optional[int] = None + ) -> Message: + """Listen on the connection for incoming message. + + If the {skip_filter} is specified, the message that + match the filters are skipped. + """ + raw_msg = self.connection.read_message() + msg = Message.read(namespace(), io.BytesIO(raw_msg)) + self.add_stash(msg.messagetype.name, msg) + return msg + + def send_msg( + self, msg_name: str, **kwargs: Union[ResolvableStr, ResolvableInt] + ) -> None: + """Send a message through the last connection""" + msgtype = namespace().get_msgtype(msg_name) + msg = Message(msgtype, **resolve_args(self, kwargs)) + missing = msg.missing_fields() + if missing: + raise SpecFileError(self, "Missing fields {}".format(missing)) + binmsg = io.BytesIO() + msg.write(binmsg) + self.connection.send_message(binmsg.getvalue()) + # FIXME: we should listen to possible connection here + class Runner(ABC): """Abstract base class for runners. @@ -189,7 +215,7 @@ def is_running(self) -> bool: pass @abstractmethod - def connect(self, event: Event, connprivkey: str) -> None: + def connect(self, event: Event, connprivkey: str) -> RunnerConn: pass def send_msg(self, msg: Message) -> None: diff --git a/lnprototest/utils/__init__.py b/lnprototest/utils/__init__.py index ea3025c..790d473 100644 --- a/lnprototest/utils/__init__.py +++ b/lnprototest/utils/__init__.py @@ -14,6 +14,12 @@ check_hex, privkey_for_index, merge_events_sequences, + Resolvable, + ResolvableBool, + ResolvableInt, + ResolvableStr, + resolve_arg, + resolve_args, ) from .bitcoin_utils import ( ScriptType, diff --git a/lnprototest/utils/utils.py b/lnprototest/utils/utils.py index 4b28bd3..3d1fc16 100644 --- a/lnprototest/utils/utils.py +++ b/lnprototest/utils/utils.py @@ -8,11 +8,17 @@ import logging import traceback -from typing import Union, Sequence, List +from typing import Union, Sequence, List, Dict, Callable, Any from enum import IntEnum from lnprototest.keyset import KeySet +# Type for arguments: either strings, or functions to call at runtime +ResolvableStr = Union[str, Callable[["RunnerConn", "Event", str], str]] +ResolvableInt = Union[int, Callable[["RunnerConn", "Event", str], int]] +ResolvableBool = Union[int, Callable[["RunnerConn", "Event", str], bool]] +Resolvable = Union[Any, Callable[["RunnerConn", "Event", str], Any]] + class Side(IntEnum): local = 0 @@ -106,3 +112,19 @@ def merge_events_sequences( """Merge the two list in the pre-post order""" pre.extend(post) return pre + + +def resolve_arg(fieldname: str, conn: "RunnerConn", arg: Resolvable) -> Any: + """If this is a string, return it, otherwise call it to get result""" + if callable(arg): + return arg(conn, fieldname) + else: + return arg + + +def resolve_args(conn: "RunnerConn", kwargs: Dict[str, Resolvable]) -> Dict[str, Any]: + """Take a dict of args, replace callables with their return values""" + ret: Dict[str, str] = {} + for field, str_or_func in kwargs.items(): + ret[field] = resolve_arg(field, conn, str_or_func) + return ret diff --git a/tests/test_v2_bolt1-01-init.py b/tests/test_v2_bolt1-01-init.py index 38c00e2..42646b1 100644 --- a/tests/test_v2_bolt1-01-init.py +++ b/tests/test_v2_bolt1-01-init.py @@ -11,10 +11,10 @@ def test_v2_init_is_first_msg(runner: Runner, namespaceoverride: Any) -> None: """ runner.start() - runner.connect(None, connprivkey="03") - init_msg = runner.recv_msg() + conn1 = runner.connect(None, connprivkey="03") + init_msg = conn1.recv_msg() assert ( init_msg.messagetype.number == 16 ), f"received not an init msg but: {init_msg.to_str()}" - + conn1.send_msg("init", globalfeatures="", features="") runner.stop()