Skip to content

Commit

Permalink
Add SGDClassifier (#17)
Browse files Browse the repository at this point in the history
* Add SGDClassifier

* Fix complaining linter
  • Loading branch information
billfreeman44 authored Jan 1, 2023
1 parent 4505d10 commit bc71e73
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pureskillgg_dsdk/ds_models/s3_scikit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def _use_model(self, dataframe):
if self._model_type == "MiniBatchKMeans":
labels = self._loaded_model.predict(dataframe)
return labels
if self._model_type == "SGDClassifier":
labels = self._loaded_model.predict_proba(dataframe)
return labels
raise Exception(f"Unknown model_type {self._model_type}")

def invoke(self, dataframe):
Expand Down

0 comments on commit bc71e73

Please sign in to comment.