blob: 5a8037ec75166a0b82d7eb42d7b432a5436a88a1 [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from typing import Dict, Union, Optional
from pyflink.common.config_options import ConfigOption
from pyflink.java_gateway import get_gateway
from pyflink.table.schema import Schema
__all__ = ['ModelDescriptor']
class ModelDescriptor(object):
    """
    Describes a CatalogModel representing a model.

    ModelDescriptor is a template for creating a CatalogModel instance. It closely resembles the
    "CREATE MODEL" SQL DDL statement, containing input schema, output schema, and other
    characteristics.

    This can be used to register a model in the Table API.

    Every method of this class delegates to the wrapped Java descriptor object.
    """

    def __init__(self, j_model_descriptor):
        """
        :param j_model_descriptor: The Java descriptor object that this instance wraps.
        """
        self._j_model_descriptor = j_model_descriptor

    @staticmethod
    def for_provider(provider: str) -> 'ModelDescriptor.Builder':
        """
        Creates a new :class:`~pyflink.table.ModelDescriptor.Builder` for a model using the given
        provider value.

        :param provider: The provider value for the model.
        :return: A builder wrapping the Java builder created for the given provider.
        """
        j_builder = get_gateway().jvm.ModelDescriptor.forProvider(provider)
        return ModelDescriptor.Builder(j_builder)

    @staticmethod
    def _unwrap_optional(j_optional, wrapper=None):
        """
        Unwraps a Java ``Optional``-style object.

        :param j_optional: Object exposing ``isPresent()`` and ``get()``.
        :param wrapper: Optional callable applied to the contained value before returning it.
        :return: ``None`` if no value is present, otherwise the (optionally wrapped) value.
        """
        if not j_optional.isPresent():
            return None
        j_value = j_optional.get()
        return j_value if wrapper is None else wrapper(j_value)

    def get_options(self) -> Dict[str, str]:
        """
        Returns the options of this descriptor, as reported by the underlying Java descriptor.
        """
        return self._j_model_descriptor.getOptions()

    def get_input_schema(self) -> Optional[Schema]:
        """
        Returns the input schema wrapped in a :class:`~pyflink.table.Schema`, or ``None`` if the
        underlying Java descriptor has no input schema.
        """
        return ModelDescriptor._unwrap_optional(
            self._j_model_descriptor.getInputSchema(), Schema)

    def get_output_schema(self) -> Optional[Schema]:
        """
        Returns the output schema wrapped in a :class:`~pyflink.table.Schema`, or ``None`` if the
        underlying Java descriptor has no output schema.
        """
        return ModelDescriptor._unwrap_optional(
            self._j_model_descriptor.getOutputSchema(), Schema)

    def get_comment(self) -> Optional[str]:
        """
        Returns the comment of this descriptor, or ``None`` if the underlying Java descriptor has
        no comment.
        """
        return ModelDescriptor._unwrap_optional(self._j_model_descriptor.getComment())

    def __str__(self):
        # The textual form is produced by the Java descriptor.
        return self._j_model_descriptor.toString()

    def __eq__(self, other):
        # Compare classes first so that ``other._j_model_descriptor`` is only read on instances
        # of exactly the same class; equality itself is decided by the Java descriptor.
        return (self.__class__ == other.__class__ and
                self._j_model_descriptor.equals(other._j_model_descriptor))

    def __hash__(self):
        # Kept consistent with __eq__ by also delegating to the Java descriptor.
        return self._j_model_descriptor.hashCode()

    class Builder(object):
        """
        Builder for ModelDescriptor.

        Every setter forwards to the wrapped Java builder and returns this builder, so that calls
        can be chained.
        """

        def __init__(self, j_builder):
            """
            :param j_builder: The Java builder object that this instance wraps.
            """
            self._j_builder = j_builder

        def input_schema(self, input_schema: Schema) -> 'ModelDescriptor.Builder':
            """
            Define the input schema of the ModelDescriptor.

            :param input_schema: The schema describing the model input.
            :return: This builder.
            """
            self._j_builder.inputSchema(input_schema._j_schema)
            return self

        def output_schema(self, output_schema: Schema) -> 'ModelDescriptor.Builder':
            """
            Define the output schema of the ModelDescriptor.

            :param output_schema: The schema describing the model output.
            :return: This builder.
            """
            self._j_builder.outputSchema(output_schema._j_schema)
            return self

        def option(self, key: Union[str, ConfigOption], value) -> 'ModelDescriptor.Builder':
            """
            Sets the given option on the model.

            Option keys must be fully specified.

            Example:
            ::

                >>> ModelDescriptor.for_provider("OPENAI") \\
                ...     .input_schema(input_schema) \\
                ...     .output_schema(output_schema) \\
                ...     .option("task", "regression") \\
                ...     .build()

            :param key: The option key, either as a string or as a
                :class:`~pyflink.common.config_options.ConfigOption`.
            :param value: The value to set for the option.
            :return: This builder.
            """
            # A ConfigOption key is passed on as its underlying Java config option object.
            j_key = key if isinstance(key, str) else key._j_config_option
            self._j_builder.option(j_key, value)
            return self

        def comment(self, comment: str) -> 'ModelDescriptor.Builder':
            """
            Define the comment for this model.

            :param comment: The comment text.
            :return: This builder.
            """
            self._j_builder.comment(comment)
            return self

        def build(self) -> 'ModelDescriptor':
            """
            Returns an immutable instance of :class:`~pyflink.table.ModelDescriptor`.
            """
            return ModelDescriptor(self._j_builder.build())