Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions mypy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from collections.abc import Sequence
from typing import TYPE_CHECKING, Final

from mypy_extensions import u8

try:
from native_internal import (
Buffer as Buffer,
Expand Down Expand Up @@ -34,10 +36,10 @@ def read_int(data: Buffer) -> int:
def write_int(data: Buffer, value: int) -> None:
raise NotImplementedError

def read_tag(data: Buffer) -> int:
def read_tag(data: Buffer) -> u8:
raise NotImplementedError

def write_tag(data: Buffer, value: int) -> None:
def write_tag(data: Buffer, value: u8) -> None:
raise NotImplementedError

def read_str(data: Buffer) -> str:
Expand All @@ -59,15 +61,18 @@ def write_float(data: Buffer, value: float) -> None:
raise NotImplementedError


LITERAL_INT: Final = 1
LITERAL_STR: Final = 2
LITERAL_BOOL: Final = 3
LITERAL_FLOAT: Final = 4
LITERAL_COMPLEX: Final = 5
LITERAL_NONE: Final = 6
# Always use this type alias to refer to type tags.
Tag = u8

LITERAL_INT: Final[Tag] = 1
LITERAL_STR: Final[Tag] = 2
LITERAL_BOOL: Final[Tag] = 3
LITERAL_FLOAT: Final[Tag] = 4
LITERAL_COMPLEX: Final[Tag] = 5
LITERAL_NONE: Final[Tag] = 6


def read_literal(data: Buffer, tag: int) -> int | str | bool | float:
def read_literal(data: Buffer, tag: Tag) -> int | str | bool | float:
if tag == LITERAL_INT:
return read_int(data)
elif tag == LITERAL_STR:
Expand Down
23 changes: 12 additions & 11 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
LITERAL_COMPLEX,
LITERAL_NONE,
Buffer,
Tag,
read_bool,
read_float,
read_int,
Expand Down Expand Up @@ -4877,17 +4878,17 @@ def local_definitions(
yield from local_definitions(node.names, fullname, node)


MYPY_FILE: Final = 0
OVERLOADED_FUNC_DEF: Final = 1
FUNC_DEF: Final = 2
DECORATOR: Final = 3
VAR: Final = 4
TYPE_VAR_EXPR: Final = 5
PARAM_SPEC_EXPR: Final = 6
TYPE_VAR_TUPLE_EXPR: Final = 7
TYPE_INFO: Final = 8
TYPE_ALIAS: Final = 9
CLASS_DEF: Final = 10
MYPY_FILE: Final[Tag] = 0
OVERLOADED_FUNC_DEF: Final[Tag] = 1
FUNC_DEF: Final[Tag] = 2
DECORATOR: Final[Tag] = 3
VAR: Final[Tag] = 4
TYPE_VAR_EXPR: Final[Tag] = 5
PARAM_SPEC_EXPR: Final[Tag] = 6
TYPE_VAR_TUPLE_EXPR: Final[Tag] = 7
TYPE_INFO: Final[Tag] = 8
TYPE_ALIAS: Final[Tag] = 9
CLASS_DEF: Final[Tag] = 10


def read_symbol(data: Buffer) -> mypy.nodes.SymbolNode:
Expand Down
39 changes: 20 additions & 19 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mypy.bogus_type import Bogus
from mypy.cache import (
Buffer,
Tag,
read_bool,
read_int,
read_int_list,
Expand Down Expand Up @@ -4120,25 +4121,25 @@ def type_vars_as_args(type_vars: Sequence[TypeVarLikeType]) -> tuple[Type, ...]:
return tuple(args)


TYPE_ALIAS_TYPE: Final = 1
TYPE_VAR_TYPE: Final = 2
PARAM_SPEC_TYPE: Final = 3
TYPE_VAR_TUPLE_TYPE: Final = 4
UNBOUND_TYPE: Final = 5
UNPACK_TYPE: Final = 6
ANY_TYPE: Final = 7
UNINHABITED_TYPE: Final = 8
NONE_TYPE: Final = 9
DELETED_TYPE: Final = 10
INSTANCE: Final = 11
CALLABLE_TYPE: Final = 12
OVERLOADED: Final = 13
TUPLE_TYPE: Final = 14
TYPED_DICT_TYPE: Final = 15
LITERAL_TYPE: Final = 16
UNION_TYPE: Final = 17
TYPE_TYPE: Final = 18
PARAMETERS: Final = 19
TYPE_ALIAS_TYPE: Final[Tag] = 1
TYPE_VAR_TYPE: Final[Tag] = 2
PARAM_SPEC_TYPE: Final[Tag] = 3
TYPE_VAR_TUPLE_TYPE: Final[Tag] = 4
UNBOUND_TYPE: Final[Tag] = 5
UNPACK_TYPE: Final[Tag] = 6
ANY_TYPE: Final[Tag] = 7
UNINHABITED_TYPE: Final[Tag] = 8
NONE_TYPE: Final[Tag] = 9
DELETED_TYPE: Final[Tag] = 10
INSTANCE: Final[Tag] = 11
CALLABLE_TYPE: Final[Tag] = 12
OVERLOADED: Final[Tag] = 13
TUPLE_TYPE: Final[Tag] = 14
TYPED_DICT_TYPE: Final[Tag] = 15
LITERAL_TYPE: Final[Tag] = 16
UNION_TYPE: Final[Tag] = 17
TYPE_TYPE: Final[Tag] = 18
PARAMETERS: Final[Tag] = 19


def read_type(data: Buffer) -> Type:
Expand Down
6 changes: 4 additions & 2 deletions mypy/typeshed/stubs/mypy-native/native_internal.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from mypy_extensions import u8

class Buffer:
def __init__(self, source: bytes = ...) -> None: ...
def getvalue(self) -> bytes: ...
Expand All @@ -10,5 +12,5 @@ def write_float(data: Buffer, value: float) -> None: ...
def read_float(data: Buffer) -> float: ...
def write_int(data: Buffer, value: int) -> None: ...
def read_int(data: Buffer) -> int: ...
def write_tag(data: Buffer, value: int) -> None: ...
def read_tag(data: Buffer) -> int: ...
def write_tag(data: Buffer, value: u8) -> None: ...
def read_tag(data: Buffer) -> u8: ...
31 changes: 13 additions & 18 deletions mypyc/lib-rt/native_internal.c
Original file line number Diff line number Diff line change
Expand Up @@ -438,18 +438,18 @@ write_int(PyObject *self, PyObject *args, PyObject *kwds) {
return Py_None;
}

static CPyTagged
static uint8_t
read_tag_internal(PyObject *data) {
if (_check_buffer(data) == 2)
return CPY_INT_TAG;
return CPY_LL_UINT_ERROR;

if (_check_read((BufferObject *)data, 1) == 2)
return CPY_INT_TAG;
return CPY_LL_UINT_ERROR;
char *buf = ((BufferObject *)data)->buf;

uint8_t ret = *(uint8_t *)(buf + ((BufferObject *)data)->pos);
((BufferObject *)data)->pos += 1;
return ((CPyTagged)ret) << 1;
return ret;
}

static PyObject*
Expand All @@ -458,27 +458,22 @@ read_tag(PyObject *self, PyObject *args, PyObject *kwds) {
PyObject *data = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist, &data))
return NULL;
CPyTagged retval = read_tag_internal(data);
if (retval == CPY_INT_TAG) {
uint8_t retval = read_tag_internal(data);
if (retval == CPY_LL_UINT_ERROR && PyErr_Occurred()) {
return NULL;
}
return CPyTagged_StealAsObject(retval);
return PyLong_FromLong(retval);
}

static char
write_tag_internal(PyObject *data, CPyTagged value) {
write_tag_internal(PyObject *data, uint8_t value) {
if (_check_buffer(data) == 2)
return 2;

if (value > MAX_SHORT_INT_TAGGED) {
PyErr_SetString(PyExc_OverflowError, "value must fit in single byte");
return 2;
}

if (_check_size((BufferObject *)data, 1) == 2)
return 2;
uint8_t *buf = (uint8_t *)((BufferObject *)data)->buf;
*(buf + ((BufferObject *)data)->pos) = (uint8_t)(value >> 1);
*(buf + ((BufferObject *)data)->pos) = value;
((BufferObject *)data)->pos += 1;
((BufferObject *)data)->end += 1;
return 1;
Expand All @@ -491,12 +486,12 @@ write_tag(PyObject *self, PyObject *args, PyObject *kwds) {
PyObject *value = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO", kwlist, &data, &value))
return NULL;
if (!PyLong_Check(value)) {
PyErr_SetString(PyExc_TypeError, "value must be an int");
uint8_t unboxed = CPyLong_AsUInt8(value);
if (unboxed == CPY_LL_UINT_ERROR && PyErr_Occurred()) {
CPy_TypeError("u8", value);
return NULL;
}
CPyTagged tagged_value = CPyTagged_BorrowFromObject(value);
if (write_tag_internal(data, tagged_value) == 2) {
if (write_tag_internal(data, unboxed) == 2) {
return NULL;
}
Py_INCREF(Py_None);
Expand Down
8 changes: 4 additions & 4 deletions mypyc/lib-rt/native_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ static char write_float_internal(PyObject *data, double value);
static double read_float_internal(PyObject *data);
static char write_int_internal(PyObject *data, CPyTagged value);
static CPyTagged read_int_internal(PyObject *data);
static char write_tag_internal(PyObject *data, CPyTagged value);
static CPyTagged read_tag_internal(PyObject *data);
static char write_tag_internal(PyObject *data, uint8_t value);
static uint8_t read_tag_internal(PyObject *data);
static int NativeInternal_ABI_Version(void);

#else
Expand All @@ -35,8 +35,8 @@ static void **NativeInternal_API;
#define read_float_internal (*(double (*)(PyObject *source)) NativeInternal_API[8])
#define write_int_internal (*(char (*)(PyObject *source, CPyTagged value)) NativeInternal_API[9])
#define read_int_internal (*(CPyTagged (*)(PyObject *source)) NativeInternal_API[10])
#define write_tag_internal (*(char (*)(PyObject *source, CPyTagged value)) NativeInternal_API[11])
#define read_tag_internal (*(CPyTagged (*)(PyObject *source)) NativeInternal_API[12])
#define write_tag_internal (*(char (*)(PyObject *source, uint8_t value)) NativeInternal_API[11])
#define read_tag_internal (*(uint8_t (*)(PyObject *source)) NativeInternal_API[12])
#define NativeInternal_ABI_Version (*(int (*)(void)) NativeInternal_API[13])

static int
Expand Down
11 changes: 6 additions & 5 deletions mypyc/primitives/misc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from mypyc.ir.ops import ERR_FALSE, ERR_MAGIC, ERR_NEVER
from mypyc.ir.ops import ERR_FALSE, ERR_MAGIC, ERR_MAGIC_OVERLAPPING, ERR_NEVER
from mypyc.ir.rtypes import (
KNOWN_NATIVE_TYPES,
bit_rprimitive,
Expand All @@ -20,6 +20,7 @@
object_rprimitive,
pointer_rprimitive,
str_rprimitive,
uint8_rprimitive,
void_rtype,
)
from mypyc.primitives.registry import (
Expand Down Expand Up @@ -426,16 +427,16 @@

function_op(
name="native_internal.write_tag",
arg_types=[object_rprimitive, int_rprimitive],
arg_types=[object_rprimitive, uint8_rprimitive],
return_type=none_rprimitive,
c_function_name="write_tag_internal",
error_kind=ERR_MAGIC,
error_kind=ERR_MAGIC_OVERLAPPING,
)

function_op(
name="native_internal.read_tag",
arg_types=[object_rprimitive],
return_type=int_rprimitive,
return_type=uint8_rprimitive,
c_function_name="read_tag_internal",
error_kind=ERR_MAGIC,
error_kind=ERR_MAGIC_OVERLAPPING,
)
12 changes: 9 additions & 3 deletions mypyc/test-data/irbuild-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -1410,18 +1410,23 @@ class TestOverload:
return x

[case testNativeBufferFastPath]
from typing import Final
from mypy_extensions import u8
from native_internal import (
Buffer, write_bool, read_bool, write_str, read_str, write_float, read_float,
write_int, read_int, write_tag, read_tag
)

Tag = u8
TAG: Final[Tag] = 1

def foo() -> None:
b = Buffer()
write_str(b, "foo")
write_bool(b, True)
write_float(b, 0.1)
write_int(b, 1)
write_tag(b, 1)
write_tag(b, TAG)

b = Buffer(b.getvalue())
x = read_str(b)
Expand All @@ -1439,7 +1444,8 @@ def foo():
r9, x :: str
r10, y :: bool
r11, z :: float
r12, t, r13, u :: int
r12, t :: int
r13, u :: u8
L0:
r0 = Buffer_internal_empty()
b = r0
Expand All @@ -1448,7 +1454,7 @@ L0:
r3 = write_bool_internal(b, 1)
r4 = write_float_internal(b, 0.1)
r5 = write_int_internal(b, 2)
r6 = write_tag_internal(b, 2)
r6 = write_tag_internal(b, 1)
r7 = Buffer_getvalue_internal(b)
r8 = Buffer_internal(r7)
b = r8
Expand Down
19 changes: 15 additions & 4 deletions mypyc/test-data/run-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -2711,11 +2711,18 @@ from native import Player
Player.MIN = <Player.MIN: 1>

[case testBufferRoundTrip_native_libs]
from typing import Final
from mypy_extensions import u8
from native_internal import (
Buffer, write_bool, read_bool, write_str, read_str, write_float, read_float,
write_int, read_int, write_tag, read_tag
)

Tag = u8
TAG_A: Final[Tag] = 33
TAG_B: Final[Tag] = 255
TAG_SPECIAL: Final[Tag] = 239

def test_buffer_basic() -> None:
b = Buffer(b"foo")
assert b.getvalue() == b"foo"
Expand All @@ -2729,8 +2736,9 @@ def test_buffer_roundtrip() -> None:
write_float(b, 0.1)
write_int(b, 0)
write_int(b, 1)
write_tag(b, 33)
write_tag(b, 255)
write_tag(b, TAG_A)
write_tag(b, TAG_SPECIAL)
write_tag(b, TAG_B)
write_int(b, 2)
write_int(b, 2 ** 85)
write_int(b, -1)
Expand All @@ -2743,8 +2751,9 @@ def test_buffer_roundtrip() -> None:
assert read_float(b) == 0.1
assert read_int(b) == 0
assert read_int(b) == 1
assert read_tag(b) == 33
assert read_tag(b) == 255
assert read_tag(b) == TAG_A
assert read_tag(b) == TAG_SPECIAL
assert read_tag(b) == TAG_B
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try reading the error value (when there is no error).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, good idea, I will add it to both compiled and interpreted tests.

assert read_int(b) == 2
assert read_int(b) == 2 ** 85
assert read_int(b) == -1
Expand All @@ -2769,6 +2778,7 @@ def test_buffer_roundtrip_interpreted() -> None:
write_int(b, 0)
write_int(b, 1)
write_tag(b, 33)
write_tag(b, 239)
write_tag(b, 255)
write_int(b, 2)
write_int(b, 2 ** 85)
Expand All @@ -2783,6 +2793,7 @@ def test_buffer_roundtrip_interpreted() -> None:
assert read_int(b) == 0
assert read_int(b) == 1
assert read_tag(b) == 33
assert read_tag(b) == 239
assert read_tag(b) == 255
assert read_int(b) == 2
assert read_int(b) == 2 ** 85
Expand Down
Loading
Loading