| # Licensed to Apache Software Foundation (ASF) under one or more contributor |
| # license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright |
| # ownership. Apache Software Foundation (ASF) licenses this file to you under |
| # the Apache License, Version 2.0 (the "License"); you may |
| # not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import logging |
| |
| import click |
| import mlflow |
| import mlflow.sklearn |
| |
| from core.data import load_data |
| |
# Root logging for the script: records at INFO and above, rendered as
# "[timestamp] logger-name LEVEL message".
_LOG_FORMAT = "[%(asctime)s] %(name)s %(levelname)s %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT, level=logging.INFO)

# Logger named after this module.
logger = logging.getLogger(__name__)
| |
| |
# algorithm name -> (module path, training function name). Modules are only
# imported when their algorithm is selected, so unused backends are never loaded.
_TRAINING_FUNCS = {
    "svm": ("core.training.svm", "train_svc"),
    "lightgbm": ("core.training.lightgbm", "train_lightgbm"),
    "xgboost": ("core.training.xgboost", "train_xgboost"),
    "lr": ("core.training.lr", "train_lr"),
}


def get_training_func(algorithm):
    """Return the training function for ``algorithm``.

    The implementing module is imported lazily, so modules belonging to
    algorithms that were not selected are never imported.

    Args:
        algorithm: one of ``"svm"``, ``"lightgbm"``, ``"xgboost"`` or ``"lr"``.

    Returns:
        The training callable exported by the matching ``core.training`` module.

    Raises:
        ValueError: if ``algorithm`` is not a supported name (including ``None``).
    """
    import importlib

    target = _TRAINING_FUNCS.get(algorithm)
    if target is None:
        # The original `assert "<non-empty string>"` could never fail, so an
        # unknown algorithm surfaced later as an UnboundLocalError instead.
        raise ValueError(
            f"{algorithm} not supported; expected one of {sorted(_TRAINING_FUNCS)}"
        )

    module_path, func_name = target
    return getattr(importlib.import_module(module_path), func_name)
| |
| |
def _best_version_by_metric(client, versions, key_metrics):
    """Return the version id whose source run has the highest ``key_metrics``.

    Versions without a source run, or whose run did not log ``key_metrics``,
    are skipped with a warning instead of aborting the whole selection.

    Raises:
        ValueError: if none of ``versions`` has a value for ``key_metrics``.
    """
    version2metrics = []
    for version in versions:
        if not version.run_id:
            logger.warning("version %s has no source run; skipped", version.version)
            continue
        value = client.get_run(version.run_id).data.metrics.get(key_metrics)
        if value is None:
            logger.warning(
                "version %s did not log metric %r; skipped", version.version, key_metrics
            )
            continue
        version2metrics.append((version.version, value))

    logger.info("version2metrics(%s): %s", key_metrics, version2metrics)

    if not version2metrics:
        raise ValueError(f"no version of the model logged metric {key_metrics!r}")
    return max(version2metrics, key=lambda item: item[1])[0]


def create_model_version(model_name, key_metrics=None, run_id=None, auto_replace=True):
    """Optionally register a run's model, then promote one version to Production.

    If ``run_id`` is given, its ``sklearn_model`` artifact is registered as a
    new version of ``model_name``. The version to promote is chosen next:

    * with ``key_metrics``: the version whose run has the highest value of that
      metric, across all versions of ``model_name``;
    * without ``key_metrics``: the version just registered from ``run_id``.

    Every other version currently in Production is moved to Archived, and the
    chosen version is moved to Production. With neither ``run_id`` nor
    ``key_metrics``, all Production versions are archived and none is promoted.

    Args:
        model_name: registered model name.
        key_metrics: name of a run metric used to rank versions (higher wins).
        run_id: MLflow run whose ``sklearn_model`` artifact should be registered.
        auto_replace: currently unused; kept for backward compatibility.

    Returns:
        The model versions of ``model_name``, re-queried after the stage changes.

    Raises:
        ValueError: if ``key_metrics`` is given but no version's run logged it.
            The selection happens before any stage transition, so the current
            Production version is left untouched in that case.
    """
    client = mlflow.tracking.MlflowClient()
    filter_string = f"name='{model_name}'"

    # NOTE(review): "no versions" does not necessarily mean "no registered
    # model" (a model whose versions were all deleted still exists) — confirm
    # create_registered_model is safe in that case.
    if not client.search_model_versions(filter_string):
        client.create_registered_model(model_name)

    new_version = None
    if run_id:
        uri = f"runs:/{run_id}/sklearn_model"
        new_version = mlflow.register_model(uri, model_name).version

    # Choose the target before archiving anything, so a failed selection can't
    # leave the model without a Production version.
    if key_metrics:
        target = _best_version_by_metric(
            client, client.search_model_versions(filter_string), key_metrics
        )
    else:
        target = new_version

    for version in client.search_model_versions(filter_string):
        is_target = target is not None and str(version.version) == str(target)
        if version.current_stage == "Production" and not is_target:
            client.transition_model_version_stage(
                model_name, version=version.version, stage="Archived"
            )

    if target is not None:
        client.transition_model_version_stage(
            model_name, version=target, stage="Production"
        )
        if key_metrics:
            logger.info("register version: %s to Production", target)
        else:
            logger.info("register last version to Production")

    # Re-query so the caller sees the new version and up-to-date stages.
    return client.search_model_versions(filter_string)
| |
| |
@click.command()
@click.option("--algorithm", required=True,
              help="Training algorithm: svm, lightgbm, xgboost or lr.")
@click.option("--data_path")
@click.option("--label_column", default="label")
@click.option("--model_name", default=None,
              help="If set, register the trained model under this name.")
@click.option("--random_state", default=0, type=int)
@click.option("--param_file", default=None)
@click.option("--params", default=None)
@click.option("--search_params", default=None)
@click.option("--key_metrics", default="f1-score",
              help="Run metric used to pick the Production model version.")
def main(algorithm, data_path, label_column, model_name, random_state,
         param_file, params, search_params, key_metrics):
    """Train a model, log its metrics and artifact to MLflow, optionally register it.

    The model is logged under the ``sklearn_model`` artifact path. When
    ``--model_name`` is given, the run's model is registered and a version is
    promoted to Production, using ``--key_metrics`` to rank the versions.
    """
    # Resolve the trainer first so an unsupported --algorithm fails before the
    # (potentially expensive) data load.
    training_func = get_training_func(algorithm)
    train_x, train_y, test_x, test_y = load_data(
        data_path, label_column, random_state=random_state
    )

    with mlflow.start_run() as run:
        model, metrics = training_func(
            train_x,
            train_y,
            test_x,
            test_y,
            param_file=param_file,
            params=params,
            search_params=search_params,
        )
        logger.info("metrics: %s", metrics)
        mlflow.log_metrics(metrics)
        mlflow.sklearn.log_model(model, artifact_path="sklearn_model")

        if model_name:
            create_model_version(
                model_name, key_metrics=key_metrics, run_id=run.info.run_id
            )
| |
| |
if __name__ == "__main__":
    # Script entry point: the click command parses the command-line options
    # and supplies main()'s arguments.
    main()