Skip to content

Commit

Permalink
Merge pull request #112 from trailofbits/list-args
Browse files Browse the repository at this point in the history
Add support for list arguments in python calls. Fix int encoding bug
  • Loading branch information
Boyan-MILANOV authored Jul 22, 2024
2 parents b967e4d + 6e5d692 commit 77c61c0
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 4 deletions.
35 changes: 32 additions & 3 deletions fickling/fickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def encode_body(self) -> bytes:
st = self.struct_types[self.num_bytes]
if not self.signed:
st = st.upper()
return struct.pack(f"{self.endianness.value}{st}")
return struct.pack(f"{self.endianness.value}{st}", self.arg)

@classmethod
def validate(cls, obj):
Expand Down Expand Up @@ -415,6 +415,34 @@ def insert(self, index: int, opcode: Opcode):
self._ast = None
self._properties = None

def _is_constant_type(self, obj: Any) -> bool:
return isinstance(obj, (int, float, str, bytes))

def _encode_python_obj(self, obj: Any) -> List[Opcode]:
"""Create an opcode sequence that builds an arbitrary python object on the top of the
pickle VM stack"""
if self._is_constant_type(obj):
return [ConstantOpcode.new(obj)]
elif isinstance(obj, list):
res = [Mark()]
for item in obj:
if self._is_constant_type(item):
res.append(ConstantOpcode.new(item))
else:
res += self._encode_python_obj(item)
res.append(List())
return res
else:
raise ValueError(f"Type {type(obj)} not supported")

def insert_python_obj(self, index: int, obj: Any) -> int:
"""Insert an opcode sequence that constructs a python object on the stack.
Returns the number of opcodes inserted"""
opcodes = self._encode_python_obj(obj)
for i, opcode in enumerate(opcodes):
self.insert(index + i, opcode)
return len(opcodes)

def insert_python(
self,
*args,
Expand All @@ -440,8 +468,9 @@ def insert_python(
self.insert(i, Mark())
i += 1
for arg in args:
self.insert(i, ConstantOpcode.new(arg))
i += 1
i += self.insert_python_obj(i, arg)
# self.insert(i, ConstantOpcode.new(arg))
# i += 1
self.insert(i, Tuple())
i += 1
if run_first:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ requires-python = ">=3.8"

[project.optional-dependencies]
torch = ["torch >= 2.1.0", "torchvision >= 0.16.1"]
lint = ["black", "mypy", "ruff"]
lint = ["black", "mypy", "ruff==0.2.0"]
test = ["pytest", "pytest-cov", "coverage[toml]", "torch >= 2.1.0", "torchvision >= 0.16.1"]
dev = ["build", "fickling[lint,test]", "twine", "torch >= 2.1.0", "torchvision >= 0.16.1"]
examples = ["numpy", "pytorchfi"]
Expand Down
22 changes: 22 additions & 0 deletions test/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,28 @@ def test_insert(self):
evaluated = loads(loaded.dumps())
self.assertEqual([5, 6, 7, 8], evaluated)

def test_insert_list_arg(self):
pickled = dumps([1, 2, 3, 4])
loaded = Pickled.load(pickled)
self.assertIsInstance(loaded[-1], fpickle.Stop)
loaded.insert_python(
[1, 2, ["a", "b"], 3],
module="builtins",
attr="tuple",
use_output_as_unpickle_result=True,
run_first=False,
)
self.assertIsInstance(loaded[-1], fpickle.Stop)

# Make sure the injected code cleans up the stack after itself:
interpreter = Interpreter(loaded)
interpreter.run()
self.assertEqual(len(interpreter.stack), 0)

# Make sure the output is correct
evaluated = loads(loaded.dumps())
self.assertEqual((1, 2, ["a", "b"], 3), evaluated)

def test_insert_run_last(self):
pickled = dumps([1, 2, 3, 4])
loaded = Pickled.load(pickled)
Expand Down

0 comments on commit 77c61c0

Please sign in to comment.