| # Licensed to Apache Software Foundation (ASF) under one or more contributor |
| # license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright |
| # ownership. Apache Software Foundation (ASF) licenses this file to you under |
| # the Apache License, Version 2.0 (the "License"); you may |
| # not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
class Tool:
    """Base interface for an AutoML tool integration.

    ``train_automl`` and ``save_automl`` raise ``NotImplementedError`` here and
    are meant to be overridden. ``eval_automl`` has a default implementation
    that delegates to the fitted object's ``score`` method.
    """

    # Path to a model artifact; unset (None) on the base class.
    model_path = None

    # Conda environment spec: channels, a pinned Python, and pip packages.
    # NOTE(review): this dict is a class attribute shared by every subclass
    # that does not override it -- mutate a copy, never this object in place.
    conda_env = dict(
        channels=["defaults", "conda-forge"],
        dependencies=[
            "python=3.8.10",
            dict(
                pip=[
                    "mlflow",
                    "scikit-learn==0.24.2",
                    "boto3==1.22.2",
                    "pandas==1.3.5",
                    "setuptools<59.6.0",
                ],
            ),
        ],
        name="mlflow-env",
    )

    @staticmethod
    def train_automl(train_x, train_y, other_params=None, **kwargs):
        """Fit an AutoML model on ``train_x`` / ``train_y``.

        Not implemented on the base class.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()

    @staticmethod
    def eval_automl(automl, test_x, test_y):
        """Score ``automl`` on held-out data.

        Args:
            automl: any object exposing ``score(test_x, test_y)``.
            test_x: evaluation features, passed straight to ``score``.
            test_y: evaluation targets, passed straight to ``score``.

        Returns:
            A dict with a single ``"score"`` key holding whatever
            ``automl.score(test_x, test_y)`` returned.
        """
        return dict(score=automl.score(test_x, test_y))

    @staticmethod
    def save_automl(automl, save_path: str):
        """Persist ``automl`` to ``save_path``.

        Not implemented on the base class.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()
| |
| |
class BasePredictor:
    """Base class for predictors that load a saved AutoML artifact on init."""

    def __init__(self, automl_path=None):
        """Create the predictor and immediately call ``load_automl``.

        Args:
            automl_path: forwarded unchanged to ``load_automl`` (may be None).
        """
        self.load_automl(automl_path)

    def predict(self, inputs):
        """Return predictions for ``inputs``.

        The base class ignores ``inputs`` and returns a new empty dict.
        """
        return dict()

    def load_automl(self, path):
        """Load model state from ``path``; a no-op on the base class."""
        return None