GeaFlow 提供了近线模型推理能力:用户只需提供调用模型的 Python 文件,无需将模型转换为 ONNX 格式,既避免了模型转换带来的性能下降,也降低了算法开发人员的部署难度。本例展示了如何使用 GeaFlow 在近线计算过程中调用 AI 模型进行推理并获取推理结果。例子中的 AI 模型是一个在 GNN 常用数据集 Cora 上训练的图节点分类模型:GeaFlow 读取节点 id 构建 vertex,在近线计算过程中将节点 id 发送给 Python 模型推理进程,调用 AI 模型得到节点的预测类别及相应的概率,再将结果返回给 GeaFlow 的 Java 进程。本例仅演示通过 GeaFlow 进行模型推理的基本流程;真实场景中的近线计算逻辑可能更加复杂,模型推理只是其中的一个步骤,获取模型结果之后还可以进行复杂的迭代计算,甚至多次调用推理模型,可根据需要自行扩展。
UDF 的项目结构如下:
package org.example; import org.apache.geaflow.api.function.io.SinkFunction; import org.apache.geaflow.api.graph.compute.IncVertexCentricCompute; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricComputeFunction; import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; import org.apache.geaflow.api.graph.function.vc.base.IncGraphInferContext; import org.apache.geaflow.api.pdata.stream.window.PWindowSource; import org.apache.geaflow.api.window.impl.SizeTumblingWindow; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; import org.apache.geaflow.common.type.primitive.IntegerType; import org.apache.geaflow.env.Environment; import org.apache.geaflow.env.EnvironmentFactory; import org.apache.geaflow.example.function.FileSink; import org.apache.geaflow.example.function.FileSource; import org.apache.geaflow.file.FileConfigKeys; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.edge.impl.ValueEdge; import org.apache.geaflow.model.graph.meta.GraphMetaType; import org.apache.geaflow.model.graph.vertex.IVertex; import org.apache.geaflow.model.graph.vertex.impl.ValueVertex; import org.apache.geaflow.pipeline.IPipelineResult; import org.apache.geaflow.pipeline.Pipeline; import org.apache.geaflow.pipeline.PipelineFactory; import org.apache.geaflow.pipeline.task.IPipelineTaskContext; import org.apache.geaflow.pipeline.task.PipelineTask; import org.apache.geaflow.view.GraphViewBuilder; import org.apache.geaflow.view.IViewDesc.BackendType; import org.apache.geaflow.view.graph.GraphViewDesc; import org.apache.geaflow.view.graph.PGraphView; import org.apache.geaflow.view.graph.PIncGraphView; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; import java.util.List; 
import java.util.Map; public class IncrGraphInferCompute { private static final Logger LOGGER = LoggerFactory.getLogger(IncrGraphInferCompute.class); // Set result dir. public static final String RESULT_FILE_PATH = "/tmp/geaflow"; public static final String INFER_PYTHON_CLASS_NAME = "myTransFormFunction"; public static void main(String[] args) { Map<String, String> config = new HashMap<>(); config.put(ExecutionConfigKeys.JOB_APP_NAME.getKey(), IncrGraphInferCompute.class.getSimpleName()); config.put(FileConfigKeys.ROOT.getKey(), "/tmp/"); Environment environment = EnvironmentFactory.onLocalEnvironment(args); Configuration configuration = environment.getEnvironmentContext().getConfig(); configuration.putAll(config); IPipelineResult result = submit(environment); result.get(); } public static IPipelineResult<?> submit(Environment environment) { final Pipeline pipeline = PipelineFactory.buildPipeline(environment); Configuration envConfig = environment.getEnvironmentContext().getConfig(); envConfig.put(FrameworkConfigKeys.INFER_ENV_ENABLE, "true"); envConfig.put(FrameworkConfigKeys.INFER_ENV_USER_TRANSFORM_CLASSNAME, INFER_PYTHON_CLASS_NAME); envConfig.put(FrameworkConfigKeys.INFER_ENV_INIT_TIMEOUT_SEC, "1800"); envConfig.put(FrameworkConfigKeys.INFER_ENV_CONDA_URL, "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-aarch64.sh"); envConfig.put(FileSink.OUTPUT_DIR, RESULT_FILE_PATH); //build graph view final String graphName = "graph_view_name"; GraphViewDesc graphViewDesc = GraphViewBuilder.createGraphView(graphName) .withShardNum(1) .withBackend(BackendType.RocksDB) .withSchema(new GraphMetaType(IntegerType.INSTANCE, ValueVertex.class, Integer.class, ValueEdge.class, IntegerType.class)) .build(); pipeline.withView(graphName, graphViewDesc); pipeline.submit(new PipelineTask() { @Override public void execute(IPipelineTaskContext pipelineTaskCxt) { Configuration conf = pipelineTaskCxt.getConfig(); PWindowSource<IVertex<Integer, List<Object>>> vertices = // 
extract vertex from edge file pipelineTaskCxt.buildSource(new FileSource<>("data/Cora/node_ids.txt", line -> { String[] fields = line.split(","); IVertex<Integer, List<Object>> vertex = new ValueVertex<>( Integer.valueOf(fields[0]), null); return Arrays.asList(vertex); }), SizeTumblingWindow.of(10000)) .withParallelism(1); PWindowSource<IEdge<Integer, Integer>> edges = pipelineTaskCxt.buildSource(new org.apache.geaflow.example.function.FileSource<>("data/Cora/node_ids.txt", line -> { String[] fields = line.split(","); IEdge<Integer, Integer> edge = new ValueEdge<>(Integer.valueOf(fields[0]), Integer.valueOf(fields[0]), 1); return Collections.singletonList(edge); }), SizeTumblingWindow.of(5000)); PGraphView<Integer, List<Object>, Integer> fundGraphView = pipelineTaskCxt.getGraphView(graphName); PIncGraphView<Integer, List<Object>, Integer> incGraphView = fundGraphView.appendGraph(vertices, edges); int mapParallelism = 1; int sinkParallelism = 1; SinkFunction<String> sink = new FileSink<>(); incGraphView.incrementalCompute(new IncGraphAlgorithms(1)) .getVertices() .map(v -> String.format("%s,%s", v.getId(), v.getValue())) .withParallelism(mapParallelism) .sink(sink) .withParallelism(sinkParallelism); } }); return pipeline.execute(); } public static class IncGraphAlgorithms extends IncVertexCentricCompute<Integer, List<Object>, Integer, Integer> { public IncGraphAlgorithms(long iterations) { super(iterations); } @Override public IncVertexCentricComputeFunction<Integer, List<Object>, Integer, Integer> getIncComputeFunction() { return new InferVertexCentricComputeFunction(); } @Override public VertexCentricCombineFunction<Integer> getCombineFunction() { return null; } } public static class InferVertexCentricComputeFunction implements IncVertexCentricComputeFunction<Integer, List<Object>, Integer, Integer> { private IncGraphComputeContext<Integer, List<Object>, Integer, Integer> graphContext; private IncGraphInferContext<List<Object>> graphInferContext; @Override public 
void init(IncGraphComputeContext<Integer, List<Object>, Integer, Integer> graphContext) { this.graphContext = graphContext; this.graphInferContext = (IncGraphInferContext<List<Object>>) graphContext; } @Override public void evolve(Integer vertexId, TemporaryGraph<Integer, List<Object>, Integer> temporaryGraph) { long lastVersionId = 0L; IVertex<Integer, List<Object>> vertex = temporaryGraph.getVertex(); HistoricalGraph<Integer, List<Object>, Integer> historicalGraph = graphContext .getHistoricalGraph(); if (vertex == null) { vertex = historicalGraph.getSnapShot(lastVersionId).vertex().get(); } if (vertex != null) { // Call the AI model to predict the class to which the node belongs and the corresponding probability. List<Object> result = this.graphInferContext.infer(vertexId); // Sink result. graphContext.collect(vertex.withValue(result)); LOGGER.info("node-{} max prob: {}, predict class: {}", vertexId, result.get(0), result.get(1)); } } @Override public void compute(Integer vertexId, Iterator<Integer> messageIterator) { } @Override public void finish(Integer vertexId, MutableGraph<Integer, List<Object>, Integer> mutableGraph) { } } }
import abc
from typing import Union, List

import torch
import ast
from torch_geometric.datasets import Planetoid
from gcn_model import GCN


def safe_int(number):
    """Convert ``number`` to ``int``; return 0 when the conversion fails."""
    try:
        return int(number)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures are treated as "unsafe" input; interrupts and
        # interpreter exits are no longer swallowed.
        return 0


def safe_float(number):
    """Convert ``number`` to ``float``; return 0.0 when the conversion fails."""
    try:
        return float(number)
    except (TypeError, ValueError, OverflowError):
        return 0.0


class TransFormFunction(abc.ABC):
    """Abstract base class for user-defined inference transforms."""

    def __init__(self, input_size):
        # Stored as given; this module does not read it again.
        self.input_size = input_size

    @abc.abstractmethod
    def load_model(self, *args):
        """Load the model used for inference."""
        pass

    @abc.abstractmethod
    def transform_pre(self, *args) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Produce the inference output for the given inputs."""
        pass

    @abc.abstractmethod
    def transform_post(self, *args):
        """Post-process the value produced by ``transform_pre``."""
        pass


# User class need to inherit TransFormFunction.
class myTransFormFunction(TransFormFunction):
    """Node classifier for the Cora dataset backed by a GCN checkpoint.

    The constructor loads the dataset and the checkpoint ``model.pt`` and runs
    the model once over the whole graph; per-node queries afterwards only index
    into the cached output.
    """

    def __init__(self):
        super().__init__(1)
        print("init myTransFormFunction")
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.dataset = Planetoid(root='./data', name='Cora')
        self.data = self.dataset[0].to(self.device)
        self.load_model('model.pt')

    def load_model(self, model_path: str):
        """Load the checkpoint at ``model_path`` and cache ``exp(model(data))``.

        The result is stored in ``self.prob``.
        """
        model = GCN(self.dataset.num_node_features, self.dataset.num_classes).to(self.device)
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host,
        # matching the CPU fallback chosen in __init__.
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        model.eval()
        # Inference only: skip autograd bookkeeping so the cached tensor does not
        # keep the computation graph alive.
        with torch.no_grad():
            out = model(self.data)
            # NOTE(review): assumes the GCN emits log-probabilities (e.g. via
            # log_softmax) so that exp() yields probabilities; confirm in gcn_model.
            self.prob = torch.exp(out)

    # Define model infer logic.
    def transform_pre(self, *args):
        """Return the top score and its class index for node ``args[0]``.

        The value is returned twice, as two separate but equal lists
        ``[max_prob, max_class]``.
        """
        node_prob = self.prob[args[0]]
        max_prob, max_class = node_prob.max(dim=0)
        result = [max_prob.item(), max_class.item()]
        # Return two distinct list objects so callers cannot alias one through
        # the other.
        return result, list(result)

    def transform_post(self, res):
        """Return ``res`` unchanged."""
        return res
# Python dependencies for the inference environment.
--index-url https://pypi.tuna.tsinghua.edu.cn/simple
torch
torchvision
torchaudio
torch-scatter
torch-sparse
torch-cluster
torch-spline-conv
torch-geometric
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>InferUDF</artifactId>
    <version>1.0-SNAPSHOT</version>
    <packaging>jar</packaging>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <!-- Single place to change the version shared by all GeaFlow modules below. -->
        <geaflow.version>0.5.0-SNAPSHOT</geaflow.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.apache.geaflow</groupId>
            <artifactId>geaflow-api</artifactId>
            <version>${geaflow.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.geaflow</groupId>
            <artifactId>geaflow-pdata</artifactId>
            <version>${geaflow.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.geaflow</groupId>
            <artifactId>geaflow-cluster</artifactId>
            <version>${geaflow.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.geaflow</groupId>
            <artifactId>geaflow-on-local</artifactId>
            <version>${geaflow.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.geaflow</groupId>
            <artifactId>geaflow-pipeline</artifactId>
            <version>${geaflow.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.geaflow</groupId>
            <artifactId>geaflow-infer</artifactId>
            <version>${geaflow.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.geaflow</groupId>
            <artifactId>geaflow-operator</artifactId>
            <version>${geaflow.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.geaflow</groupId>
            <artifactId>geaflow-common</artifactId>
            <version>${geaflow.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.geaflow</groupId>
            <artifactId>geaflow-examples</artifactId>
            <version>${geaflow.version}</version>
        </dependency>
    </dependencies>
</project>