Track Artifacts with Weights and Biases
Question
How can I track artifacts of my flows with Weights and Biases?
Solution
You can track flow artifacts using any Weights and Biases calls you already use. This can be especially useful if you want to track artifacts during the lifecycle of long-running tasks.
1. Log in to Weights and Biases
To run this code, first sign up for a Weights and Biases account and make sure you are logged in with your API key. It is recommended that you store the key as an environment variable. In the example shown later, the Weights and Biases "entity" and "project" are also stored as environment variables:
export WANDB_API_KEY=<YOUR KEY>
export WANDB_ENTITY=<YOUR USERNAME>
export WANDB_PROJECT=<YOUR PROJECT>
Then you can install and log in to the Weights and Biases Python client:
pip install wandb
If you don't set the WANDB_API_KEY environment variable, you will need to paste your key after running:
wandb login
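If the key is already set as an environment variable, the client authenticates without prompting. As a minimal sketch (an optional sanity check, not part of the original example), you can confirm the setup from Python before running the flow:
import os
import wandb

# Fail early if the variables the flow relies on are missing.
assert os.getenv("WANDB_API_KEY"), "Set WANDB_API_KEY before running the flow."
assert os.getenv("WANDB_ENTITY") and os.getenv("WANDB_PROJECT"), "Set WANDB_ENTITY and WANDB_PROJECT."

# wandb.login() picks up WANDB_API_KEY from the environment and only prompts
# for a key if none is found.
wandb.login()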
2. Define a Logging Function
Here is a function that takes a dataset and a classification model as arguments and logs plots with Weights and Biases. It uses the Weights and Biases scikit-learn integration, but you can replace it with arbitrary logging functions relevant to your workflow.
import os
import wandb

def plot_results(X_train, y_train, X_test, y_test,
                 y_pred, y_probs, clf, labels):
    # Start a W&B run using the entity and project stored as environment variables.
    wandb.init(entity=os.getenv("WANDB_ENTITY"),
               project=os.getenv("WANDB_PROJECT"))
    # Log plots from the W&B scikit-learn integration.
    wandb.sklearn.plot_class_proportions(y_train, y_test, labels)
    wandb.sklearn.plot_learning_curve(clf, X_train, y_train)
    wandb.sklearn.plot_roc(y_test, y_probs, labels)
    wandb.sklearn.plot_precision_recall(y_test, y_probs, labels)
    wandb.sklearn.plot_feature_importances(clf)
    wandb.sklearn.plot_classifier(
        clf, X_train, X_test, y_train, y_test, y_pred,
        y_probs, labels, is_binary=True,
        model_name='RandomForest'
    )
    # Mark the run as finished so later calls can start fresh runs.
    wandb.finish()
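The function above relies on the scikit-learn integration, but nothing about it is specific to Metaflow. As a hedged sketch of an "arbitrary" replacement (the helper name log_metrics and the metric choice are illustrative, not from the original example), you could log plain metrics instead:
import os
import wandb
from sklearn.metrics import accuracy_score

def log_metrics(y_test, y_pred):
    # Start a W&B run and log a single scalar metric instead of the sklearn plots.
    wandb.init(entity=os.getenv("WANDB_ENTITY"),
               project=os.getenv("WANDB_PROJECT"))
    wandb.log({"accuracy": accuracy_score(y_test, y_pred)})
    wandb.finish()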
3. Run the Flow
The flow shows how to:
- Load data in the start step.
- Build a model in the model step and call the custom logging function plot_results to log plots to Weights and Biases. This step uses Metaflow's @environment decorator to pass environment variables relevant to Weights and Biases into the step, which is useful when you want to track a step that runs on a remote machine using a Metaflow decorator like @batch or @kubernetes.
from metaflow import FlowSpec, step, environment, batch, conda_base
import os
import wandb
from wandb_helpers import plot_results


@conda_base(libraries={"wandb": "0.12.15", "scikit-learn": "1.0.2", "pandas": "1.4.2"})
class TrackPlotsFlow(FlowSpec):

    @step
    def start(self):
        # Load the iris dataset and split it into train and test sets.
        from sklearn import datasets
        from sklearn.model_selection import train_test_split
        self.iris = datasets.load_iris()
        self.X = self.iris['data']
        self.y = self.iris['target']
        self.labels = self.iris['target_names']
        split = train_test_split(self.X, self.y, test_size=0.2)
        self.X_train = split[0]
        self.X_test = split[1]
        self.y_train = split[2]
        self.y_test = split[3]
        self.next(self.model)

    # Copy env vars to tasks on a different machine.
    @environment(vars={
        "WANDB_NAME": "Plot RandomForestClassifier",
        "WANDB_API_KEY": os.getenv("WANDB_API_KEY"),
        "WANDB_PROJECT": os.getenv("WANDB_PROJECT"),
        "WANDB_ENTITY": os.getenv("WANDB_ENTITY"),
    })
    @batch(cpu=2)
    @step
    def model(self):
        # Fit the classifier and log plots to W&B via the custom helper.
        from sklearn.ensemble import RandomForestClassifier
        self.clf = RandomForestClassifier(
            n_estimators=10, max_depth=None,
            min_samples_split=2, random_state=0
        )
        self.clf.fit(self.X_train, self.y_train)
        self.y_pred = self.clf.predict(self.X_test)
        self.y_probs = self.clf.predict_proba(self.X_test)
        plot_results(self.X_train, self.y_train,
                     self.X_test, self.y_test,
                     self.y_pred, self.y_probs,
                     self.clf, self.labels)
        self.next(self.end)

    @step
    def end(self):
        print("Flow is all done.")


if __name__ == "__main__":
    TrackPlotsFlow()
Run the flow with the conda environment enabled:
python track_with_wandb_custom.py --environment=conda run
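After the run finishes, Weights and Biases has the logged plots, and Metaflow independently stores the flow's artifacts (the data splits, the fitted model, and the predictions). As a minimal sketch using the Metaflow Client API (the flow name TrackPlotsFlow comes from the code above), you can inspect those artifacts afterwards:
from metaflow import Flow

run = Flow("TrackPlotsFlow").latest_run
print(run.data.clf)           # the fitted RandomForestClassifier
print(run.data.y_pred[:5])    # predictions stored as flow artifacts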