MLflow is a tool for experiment tracking and model registry. The example script below, which also uses Typer for its command-line interface, keeps track of experiments and models in a local SQLite database; the MLflow UI can then be served against it with mlflow ui --port 8080 --backend-store-uri sqlite:///mlruns.db. Both the backend store and the UI can be replaced with cloud-hosted servers.
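For instance, pointing the client at a remote tracking server is only a matter of swapping the URI (the server URL below is a placeholder, not a real endpoint):

import mlflow

# hypothetical remote tracking server; replace with your own deployment's URL
mlflow.set_tracking_uri("https://mlflow.example.com")
mlflow.set_experiment("test_model")

In the script itself, the Model class and the load_data helper are minimal stand-ins for a real architecture and dataset, so that the example runs end to end.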
from typing import Annotated

import mlflow
import numpy as np
import torch
import typer
from mlflow.pyfunc import PythonModel, PythonModelContext
from sklearn.datasets import load_diabetes
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

app = typer.Typer()

class Model(torch.nn.Module):
    # minimal linear model; a stand-in for a real architecture
    def __init__(self, n_features: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_features, 1)

    def forward(self, X):
        return self.linear(X).squeeze(-1)

class WrappedModel(PythonModel):
    # generic pyfunc wrapper so the torch model can be logged and reloaded
    def __init__(self, model: torch.nn.Module):
        self.model = model

    def predict(self, context: PythonModelContext, model_input, params=None):
        with torch.no_grad():
            X = torch.as_tensor(np.asarray(model_input), dtype=torch.float32)
            return self.model(X).numpy()

def load_data():
    # example regression dataset; a stand-in for real data loading
    X, y = load_diabetes(return_X_y=True)
    return X.astype(np.float32), y.astype(np.float32)

def train(X, y, model, loss_fn, optimizer, n_iterations):
    model.train()
    for iteration in range(n_iterations):
        pred = model(X)
        loss = loss_fn(pred, y)
        # backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if iteration % 100 == 0:
            mlflow.log_metric("loss", loss.item(), step=iteration // 100)

@app.command()
def main(
    seed: Annotated[int, typer.Option(help="Seed for random state.")] = 2,
    test_size: Annotated[float, typer.Option(help="Proportion of the test dataset.")] = 0.2,
    n_iterations: Annotated[int, typer.Option(help="Number of iterations to train the model.")] = 1000,
    lr: Annotated[float, typer.Option(help="Learning rate.")] = 0.1,
):
    uri: str = "sqlite:///mlruns.db"
    experiment_name: str = "test_model"
    run_name: str = "test_run"

    mlflow.set_tracking_uri(uri)
    mlflow.set_experiment(experiment_name)

    X, y = load_data()
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=seed)
    X_train, y_train = torch.as_tensor(X_train), torch.as_tensor(y_train)
    X_test = torch.as_tensor(X_test)

    model = Model(n_features=X_train.shape[1])
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    dataset_params = {
        "seed": seed,
        "test_size": test_size,
    }
    model_params = {
        "optimizer": "Adam",
        "learning_rate": lr,
        "n_iterations": n_iterations,
    }

    with mlflow.start_run(run_name=run_name):
        mlflow.set_tag("info", "Some words")
        mlflow.log_params(dataset_params)
        mlflow.log_params(model_params)
        # any `mlflow.log_*` calls need to happen inside the `with` block
        train(X_train, y_train, model, loss_fn, optimizer, n_iterations)
        model.eval()
        with torch.no_grad():
            test_pred = model(X_test).numpy()
        mlflow.log_metric("Test MSE", mean_squared_error(y_test, test_pred))
        # generic pyfunc flavour; most frameworks have a dedicated log_model (e.g. mlflow.pytorch.log_model, see docs)
        mlflow.pyfunc.log_model(artifact_path="model", python_model=WrappedModel(model))

if __name__ == "__main__":
    app()
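Assuming the script above is saved as train.py (the filename and option values are illustrative), Typer exposes the function arguments as command-line flags, and the model logged under the run can later be reloaded through the generic pyfunc flavour:

# run from a shell, e.g.:
#   python train.py --seed 42 --test-size 0.25 --lr 0.01 --n-iterations 2000
import mlflow
import numpy as np

mlflow.set_tracking_uri("sqlite:///mlruns.db")
run_id = "..."  # fill in with a run ID from the MLflow UI
loaded = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")
print(loaded.predict(np.random.rand(5, 10).astype(np.float32)))  # 10 features, as in the diabetes dataset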
There is built-in autologging support for scikit-learn, PyTorch Lightning, Keras, etc. More general model support is available as well, with examples for PyMC and for PyTorch.
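As a minimal sketch of autologging (the scikit-learn dataset and model here are arbitrary choices), a single mlflow.autolog() call is enough to capture parameters, metrics, and the fitted model:

import mlflow
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split

mlflow.set_tracking_uri("sqlite:///mlruns.db")
mlflow.autolog()  # patches supported libraries so fit() calls are logged automatically

X, y = load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=2)

with mlflow.start_run():
    model = RandomForestRegressor(n_estimators=100, random_state=2)
    model.fit(X_train, y_train)  # parameters, training metrics, and the model artifact are logged
    print(model.score(X_test, y_test))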