Add and Remove Tags
Question
How can I add tags to and remove tags from flow runs programmatically?
Solution
You can do this within a flow or using the Client API.
1. Run Flow
This flow shows how to:
- Load a dataset.
- Train a Scikit-learn model.
- Evaluate the model on a test set.
- Tag the run as a promising model if the model's accuracy is greater than accuracy_threshold.
add_remove_tags_programmatically.py
from metaflow import FlowSpec, step, Flow, current, Parameter

class ModelTaggingFlow(FlowSpec):

    max_depth = Parameter('max-depth', default=2)
    tag_msg = 'Tagging run {} as a promising model'
    accuracy_threshold = 0.85

    @step
    def start(self):
        # Load the wine dataset and split it into train and test sets.
        from sklearn import datasets
        from sklearn.model_selection import train_test_split
        data = datasets.load_wine()
        data = train_test_split(data['data'],
                                data['target'],
                                random_state = 42)
        self.X_train = data[0]
        self.X_test = data[1]
        self.y_train = data[2]
        self.y_test = data[3]
        self.next(self.train)

    @step
    def train(self):
        # Fit a decision tree using the max-depth parameter.
        from sklearn.tree import DecisionTreeClassifier
        self.params = {
            'max_leaf_nodes': None,
            'max_depth': self.max_depth,
            'max_features' : 'sqrt',
            'random_state': 0
        }
        self.model = DecisionTreeClassifier(**self.params)
        self.model.fit(self.X_train, self.y_train)
        self.next(self.eval_and_tag)

    @step
    def eval_and_tag(self):
        from sklearn.metrics import (accuracy_score,
                                     classification_report)
        self.pred = self.model.predict(self.X_test)
        self.accuracy = float(
            accuracy_score(self.y_test, self.pred))
        print(self.accuracy)
        # Tag the current run only if it beats the threshold.
        if self.accuracy > self.accuracy_threshold:
            print(self.tag_msg.format(current.run_id))
            run = Flow(current.flow_name)[current.run_id]
            run.add_tag('promising model')
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == '__main__':
    ModelTaggingFlow()
python add_remove_tags_programmatically.py run
python add_remove_tags_programmatically.py run --max-depth 6
2. Observe Model Scores
You can use the Client API to get the latest flow runs. Here is a way to list the accuracy value of each Run tagged promising model.
from metaflow import Flow
flow = Flow('ModelTaggingFlow')
tag = 'promising model'
runs = list(flow.runs(tag))
print("All models tagged with `{}`:".format(tag))
for run in runs:
    acc = round(100 * run.data.accuracy, 2)
    print("\tRun {}: {}% Accuracy".format(run.id, acc))
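Before changing tags, it can help to see which tags each run already carries. The following is a minimal sketch, assuming a recent Metaflow release in which Run objects expose user_tags and system_tags attributes; summarize_tags is a hypothetical helper name used here for illustration, not part of Metaflow.

```python
def summarize_tags(runs):
    """Map each run id to its sorted user and system tags.

    `runs` is any iterable of objects with `id`, `user_tags` and
    `system_tags` attributes (assumed to match Metaflow Run objects).
    """
    summary = {}
    for run in runs:
        summary[run.id] = {
            'user': sorted(run.user_tags),
            'system': sorted(run.system_tags),
        }
    return summary

# Example against a real deployment (not executed here):
#   from metaflow import Flow
#   tags_by_run = summarize_tags(Flow('ModelTaggingFlow').runs())
#   for run_id, tags in tags_by_run.items():
#       print(run_id, tags['user'])
```

Only user tags can be changed through the tagging functions; system tags are assigned by Metaflow itself.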
3. Update Tags Using the Client API
You can use the run.add_tag, run.remove_tag, or run.replace_tag functions to change the tags on a Run.
These lines add the production candidate tag to each run tagged promising model whose accuracy score is above 87%.
flow = Flow('ModelTaggingFlow')
runs = list(flow.runs('promising model'))
for run in runs:
    if run.data.accuracy > .87:
        run.add_tag('production candidate')
Now you can see the model accuracy only for these models. This can be a useful pattern when reviewing models or testing and promoting them to production.
flow = Flow('ModelTaggingFlow')
tag = 'production candidate'
runs = list(flow.runs(tag))
print("All models tagged `{}`:".format(tag))
for run in runs:
    acc = round(100 * run.data.accuracy, 2)
    print("\tRun {}: {}% Accuracy".format(run.id, acc))
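Removing or swapping a tag follows the same pattern as adding one. The sketch below is one possible way to do it, using run.remove_tag(tag) and run.replace_tag(tag_to_remove, tag_to_add). The helper name demote_runs, the 0.9 cutoff and the retired candidate tag are hypothetical choices made for illustration.

```python
def demote_runs(runs, tag, min_accuracy, replacement=None):
    """Remove `tag` from runs scoring below `min_accuracy`.

    If `replacement` is given, swap the tag with run.replace_tag
    instead of only removing it. Returns the ids of changed runs.
    """
    changed = []
    for run in runs:
        if run.data.accuracy >= min_accuracy:
            continue  # still good enough; leave its tags alone
        if replacement is None:
            run.remove_tag(tag)
        else:
            run.replace_tag(tag, replacement)
        changed.append(run.id)
    return changed

# Example against a real deployment (not executed here):
#   from metaflow import Flow
#   runs = list(Flow('ModelTaggingFlow').runs('production candidate'))
#   demote_runs(runs, 'production candidate', 0.9,
#               replacement='retired candidate')
```

Materializing the runs with list(...) before the loop, as the earlier snippets do, avoids changing tags while still iterating over a tag-filtered query.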