Fixed an issue with OneHotEncoderOp when computing categorySizes (#873)
The Spark class org.apache.spark.ml.feature.OneHotEncoderModel has two mixins for the input columns, exposing the params inputCol and inputCols. We need to check which param is set and use the correct one to compute categorySizes.
ltrottier-yelp authored Jul 3, 2024
1 parent 8784e8e commit 43993e1
Showing 2 changed files with 99 additions and 11 deletions.
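As background (not part of the commit), a minimal sketch of the param check the fix relies on, assuming Spark ML 3.x where OneHotEncoderModel mixes in both HasInputCol and HasInputCols; the helper name resolveInputCols is hypothetical:

import org.apache.spark.ml.feature.OneHotEncoderModel

// Hypothetical helper: pick whichever input-column param is actually set.
// Spark refuses to fit an encoder with both inputCol and inputCols set
// (see the parity tests below), so at most one branch applies to a valid model.
def resolveInputCols(model: OneHotEncoderModel): Array[String] =
  if (model.isSet(model.inputCol)) Array(model.getInputCol)
  else model.getInputCols

// e.g. an encoder built with setInputCol("state_index") resolves to Array("state_index"),
// while setInputCols(Array("a", "b")) resolves to Array("a", "b").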
@@ -42,32 +42,83 @@ class OneHotEncoderOp extends SimpleSparkOp[OneHotEncoderModel] {
override def store(model: Model, obj: OneHotEncoderModel)
(implicit context: BundleContext[SparkBundleContext]): Model = {
assert(context.context.dataset.isDefined, BundleHelper.sampleDataframeMessage(klazz))

assert(!(obj.isSet(obj.inputCol) && obj.isSet(obj.inputCols)), "OneHotEncoderModel cannot have both inputCol and inputCols set")
assert(!(obj.isSet(obj.outputCol) && obj.isSet(obj.outputCols)), "OneHotEncoderModel cannot have both outputCol and outputCols set")
val inputCols = if (obj.isSet(obj.inputCol)) Array(obj.getInputCol) else obj.getInputCols
val df = context.context.dataset.get
val categorySizes = obj.getInputCols.map { f => OneHotEncoderOp.sizeForField(df.schema(f)) }

model.withValue("category_sizes", Value.intList(categorySizes))
val categorySizes = inputCols.map { f => OneHotEncoderOp.sizeForField(df.schema(f)) }
var m = model.withValue("category_sizes", Value.intList(categorySizes))
.withValue("drop_last", Value.boolean(obj.getDropLast))
.withValue("handle_invalid", Value.string(obj.getHandleInvalid))

if (obj.isSet(obj.inputCol)) {
m = m.withValue("inputCol", Value.string(obj.getInputCol))
}
if (obj.isSet(obj.inputCols)) {
m = m.withValue("inputCols", Value.stringList(obj.getInputCols))
}
if (obj.isSet(obj.outputCol)) {
m = m.withValue("outputCol", Value.string(obj.getOutputCol))
}
if (obj.isSet(obj.outputCols)) {
m = m.withValue("outputCols", Value.stringList(obj.getOutputCols))
}
m
}

override def load(model: Model)
(implicit context: BundleContext[SparkBundleContext]): OneHotEncoderModel = {
new OneHotEncoderModel(uid = "", categorySizes = model.value("category_sizes").getIntList.toArray)
.setDropLast(model.value("drop_last").getBoolean)
.setHandleInvalid(model.value("handle_invalid").getString)
val m = new OneHotEncoderModel(uid = "", categorySizes = model.value("category_sizes").getIntList.toArray)
.setDropLast(model.value("drop_last").getBoolean)
.setHandleInvalid(model.value("handle_invalid").getString)
if (model.getValue("inputCol").isDefined) {
m.setInputCol(model.value("inputCol").getString)
}
if (model.getValue("inputCols").isDefined) {
m.setInputCols(model.value("inputCols").getStringList.toArray)
}
if (model.getValue("outputCol").isDefined) {
m.setOutputCol(model.value("outputCol").getString)
}
if (model.getValue("outputCols").isDefined) {
m.setOutputCols(model.value("outputCols").getStringList.toArray)
}
m
}
}

override def sparkLoad(uid: String, shape: NodeShape, model: OneHotEncoderModel): OneHotEncoderModel = {
new OneHotEncoderModel(uid = uid, categorySizes = model.categorySizes)
val m = new OneHotEncoderModel(uid = uid, categorySizes = model.categorySizes)
.setDropLast(model.getDropLast)
.setHandleInvalid(model.getHandleInvalid)
if (model.isSet(model.inputCol)) {
m.setInputCol(model.getInputCol)
}
if (model.isSet(model.inputCols)) {
m.setInputCols(model.getInputCols)
}
if (model.isSet(model.outputCol)) {
m.setOutputCol(model.getOutputCol)
}
if (model.isSet(model.outputCols)) {
m.setOutputCols(model.getOutputCols)
}
m
}

override def sparkInputs(obj: OneHotEncoderModel): Seq[ParamSpec] = Seq(ParamSpec("input", obj.inputCols))
override def sparkInputs(obj: OneHotEncoderModel): Seq[ParamSpec] = {
obj.isSet(obj.inputCol) match {
case true => Seq(ParamSpec("input", obj.inputCol))
case false => Seq(ParamSpec("input", obj.inputCols))
}

}

override def sparkOutputs(obj: OneHotEncoderModel): Seq[ParamSpec] = Seq(ParamSpec("output", obj.outputCols))
override def sparkOutputs(obj: OneHotEncoderModel): Seq[ParamSpec] = {
obj.isSet(obj.outputCol) match {
case true => Seq(ParamSpec("output", obj.outputCol))
case false => Seq(ParamSpec("output", obj.outputCols))
}

}

}
@@ -1,5 +1,6 @@
package org.apache.spark.ml.parity.feature

import org.apache.spark.ml.bundle.SparkBundleContext
import org.apache.spark.ml.parity.SparkParityBase
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
import org.apache.spark.ml.{Pipeline, Transformer}
@@ -22,4 +23,40 @@ class OneHotEncoderParitySpec extends SparkParityBase {
.fit(dataset)

override val unserializedParams = Set("stringOrderType")

it("serializes/deserializes the Spark model properly with one in/out column"){
bundleCache = None
val additionalIgnoreParams = Set("outputCol")
val pipeline = new Pipeline()
.setStages(Array(
new StringIndexer().setInputCol("state").setOutputCol("state_index"),
new OneHotEncoder().setInputCol("state_index").setOutputCol("state_oh")
)).fit(dataset)
val sparkTransformed = pipeline.transform(dataset)
implicit val sbc = SparkBundleContext().withDataset(sparkTransformed)
val deserializedTransformer = deserializedSparkTransformer(pipeline)
checkEquality(pipeline, deserializedTransformer, additionalIgnoreParams)
equalityTest(sparkTransformed, deserializedTransformer.transform(dataset))
bundleCache = None
}

it("fails to instantiate if the Spark model sets inputCol and inputCols"){
intercept[IllegalArgumentException] {
new OneHotEncoder()
.setInputCol("state")
.setInputCols(Array("state_index", "state_index2"))
.setOutputCols(Array("state_oh", "state_oh2"))
.fit(dataset)
}
}

it("fails to instantiate if the Spark model sets outputCol and outputCols"){
intercept[IllegalArgumentException] {
new OneHotEncoder()
.setInputCol("state")
.setOutputCol("state_oh")
.setOutputCols(Array("state_oh", "state_oh2"))
.fit(dataset)
}
}
}
