Skip to content

Commit 60ea636

Browse files
authored
Merge pull request #10 from ketgo/issue-9
Fix for parsing marshmallow nested field with many=True argument.
2 parents 03718cf + 3efcd3a commit 60ea636

File tree

5 files changed

+30
-5
lines changed

5 files changed

+30
-5
lines changed

.github/workflows/ci.yml

+1-1
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ jobs:
1313
strategy:
1414
fail-fast: false
1515
matrix:
16-
python-version: ["3.6", "3.7", "3.8"]
16+
python-version: ["3.7", "3.8"]
1717

1818
steps:
1919
- uses: actions/checkout@v3

.gitignore

+4-1
Original file line number | Diff line number | Diff line change
@@ -73,9 +73,12 @@ fabric.properties
7373

7474
.idea/*
7575

76+
# Python files and dirs to ignore
7677
venv/*
77-
7878
*.egg-info
79+
*.pyc
80+
__pycache__/
81+
dist/
7982

8083
.coverage
8184

marshmallow_pyspark/converters.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -156,4 +156,4 @@ def convert(self, ma_field: ma_fields.Nested) -> DataType:
156156
nullable=True
157157
)
158158
)
159-
return StructType(_fields)
159+
return ArrayType(StructType(_fields)) if ma_field.many else StructType(_fields)

marshmallow_pyspark/version.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
Version for marshmallow_pyspark package
33
"""
44

5-
__version__ = '0.2.3' # pragma: no cover
5+
__version__ = '0.2.4' # pragma: no cover
66

77

88
def version_info(): # pragma: no cover

tests/test_schema.py

+23-1
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,9 @@ def test_create():
2929
(fields.Integer(), IntegerType()),
3030
(fields.Number(), DoubleType()),
3131
(fields.List(fields.String()), ArrayType(StringType())),
32-
(fields.Nested(Schema.from_dict({"name": fields.String()})), StructType([StructField("name", StringType())]))
32+
(fields.Nested(Schema.from_dict({"name": fields.String()})), StructType([StructField("name", StringType())])),
33+
(fields.Nested(Schema.from_dict({"name": fields.String()}), many=True),
34+
ArrayType(StructType([StructField("name", StringType())])))
3335
])
3436
def test_spark_schema(ma_field, spark_field):
3537
class TestSchema(Schema):
@@ -110,6 +112,26 @@ class TestSchema(Schema):
110112
[
111113
{"name": "invalid_1", "book": {"author": "Sam", "title": "Sam's Book", "cost": "32a"}},
112114
]
115+
),
116+
(
117+
Schema.from_dict({
118+
"name": fields.String(required=True),
119+
"book": fields.Nested(
120+
Schema.from_dict({
121+
"author": fields.String(required=True),
122+
"title": fields.String(required=True),
123+
"cost": fields.Number(required=True)
124+
}),
125+
many=True
126+
)
127+
}),
128+
[
129+
{"name": "valid_1", "book": [{"author": "Sam", "title": "Sam's Book", "cost": "32.5"}]},
130+
],
131+
[
132+
{"name": "valid_1", "book": [{"author": "Sam", "title": "Sam's Book", "cost": 32.5}]},
133+
],
134+
[]
113135
)
114136
])
115137
def test_validate_df(spark_session, schema, input_data, valid_rows, invalid_rows):

0 commit comments

Comments
 (0)