Use Scikit-learn with Metaflow
Question
I have a scikit-learn workflow that I want to incorporate into a Metaflow flow. How can I include model fitting, prediction, feature transformations, and other capabilities enabled by scikit-learn in flow steps?
Solution
This example uses a random forest classifier, but the same approach applies to any scikit-learn estimator.
To turn a scikit-learn workflow into a Metaflow flow, first decide what the steps will be. In this case, there are three distinct steps:
- Load data.
- Instantiate a model.
- Train a model with cross-validation.
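Before splitting anything into steps, it can help to see the same workflow as a plain script. The sketch below is an assumption about what the starting scikit-learn code might look like. Each commented section maps onto one of the steps listed above, using the same dataset and hyperparameters as the flow that follows.

```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

# Load data (becomes the start step).
iris = datasets.load_iris()
X, y = iris["data"], iris["target"]

# Instantiate a model (becomes the rf_model step).
clf = RandomForestClassifier(
    n_estimators=10, max_depth=None, min_samples_split=2, random_state=0
)

# Train with 5-fold cross-validation (becomes the train step).
scores = cross_val_score(clf, X, y, cv=5)
print(f"Mean accuracy: {scores.mean():.3f}")
```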
1. Estimators to Flows
Deciding how to split a workflow into steps involves some design choices, and we have some rules of thumb here. One benefit of separating a flow into Metaflow steps is that you can resume failed computation from any step without recomputing everything before it, which makes development much faster.
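The question also mentions feature transformations. One possible design choice, sketched here as an assumption rather than the approach used in this example, is to bundle a transformer and the estimator into a single scikit-learn Pipeline. The transformation is then refit inside each cross-validation fold, and one object (for example, self.clf) can carry both pieces between steps. StandardScaler is only a stand-in: tree ensembles such as random forests do not need scaled features.

```python
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = datasets.load_iris(return_X_y=True)

# A single estimator object that first transforms features, then classifies.
pipeline = make_pipeline(
    StandardScaler(),
    RandomForestClassifier(n_estimators=10, random_state=0),
)

# cross_val_score fits the whole pipeline on each training fold, so the
# scaler never sees statistics from the held-out fold.
scores = cross_val_score(pipeline, X, y, cv=5)
print(f"Pipeline mean accuracy: {scores.mean():.3f}")
```

In a flow, the pipeline would simply replace the bare classifier assigned to self.clf in the model step.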
2. Run Flow
This flow shows how to:
- Import FlowSpec and step.
- Include step-specific imports within each step.
- Assign any data structures you wish to pass between steps to self.
- Train a model and apply cross-validation to evaluate it.
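from metaflow import FlowSpec, step


class SklearnFlow(FlowSpec):

    @step
    def start(self):
        # Load the iris dataset. Anything assigned to self becomes an
        # artifact that later steps can read.
        from sklearn import datasets
        self.iris = datasets.load_iris()
        self.X = self.iris['data']
        self.y = self.iris['target']
        self.next(self.rf_model)

    @step
    def rf_model(self):
        # Instantiate (but do not fit) the classifier.
        from sklearn.ensemble import RandomForestClassifier
        self.clf = RandomForestClassifier(
            n_estimators=10,
            max_depth=None,
            min_samples_split=2,
            random_state=0
        )
        self.next(self.train)

    @step
    def train(self):
        # Evaluate with 5-fold cross-validation. cross_val_score clones the
        # estimator for each fold, so self.clf itself stays unfitted.
        from sklearn.model_selection import cross_val_score
        self.scores = cross_val_score(self.clf, self.X,
                                      self.y, cv=5)
        self.next(self.end)

    @step
    def end(self):
        # Artifacts from earlier steps, such as self.scores, are available here.
        print(f"Mean cross-validation accuracy: {self.scores.mean():.3f}")
        print("SklearnFlow is all done.")


if __name__ == "__main__":
    SklearnFlow()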
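The flow stops at cross-validation, but the question also asks about prediction. A hypothetical extra step could fit the estimator and predict on held-out data. Because Metaflow persists values assigned to self (by default, by pickling them), anything stored this way, such as a fitted model, has to survive a pickle round trip. The sketch below checks that outside Metaflow; the train/test split and variable names are illustrative assumptions, not part of the original flow.

```python
import pickle

from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = datasets.load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y
)

# Fit and predict, as a hypothetical step after train might do.
clf = RandomForestClassifier(n_estimators=10, random_state=0)
clf.fit(X_train, y_train)
predictions = clf.predict(X_test)

# Simulate passing the fitted model between steps: serialize, then restore.
restored = pickle.loads(pickle.dumps(clf))
restored_predictions = restored.predict(X_test)

print("Test accuracy:", clf.score(X_test, y_test))
```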
This example uses the --with card CLI option, which attaches a Metaflow card to every step of the run. Cards produce HTML visualizations.
python fit_sklearn_estimator.py run --with card
3. View Card
Now you can view the card for the train step using this command:
python fit_sklearn_estimator.py card view train