/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ignite.ml.inference.parser;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.ignite.ml.inference.Model;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
/**
* Base class for TensorFlow model parsers. Contains the logic that is common to models saved in the "SavedModel"
* format and as a simple graph.
*
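* <p>A minimal usage sketch (the {@code MyGraphModelParser} subclass, the node names and the serialized graph bytes
* below are hypothetical; the tensor helpers assume the TensorFlow 1.x Java API):
* <pre>{@code
* TensorFlowBaseModelParser<double[], Double> parser = new MyGraphModelParser()
*     .withInput("input", arr -> Tensors.create(arr))
*     .withOutput(Collections.singletonList("output"), tensors -> tensors.get("output").doubleValue());
*
* Model<double[], Double> infMdl = parser.parse(rawGraphDef);
* Double prediction = infMdl.predict(new double[] {1.0, 2.0, 3.0});
* infMdl.close();
* }</pre>
*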
* @param <I> Type of model input.
* @param <O> Type of model output.
*/
public abstract class TensorFlowBaseModelParser<I, O> implements ModelParser<I, O, Model<I, O>> {
/** */
private static final long serialVersionUID = 5574259553625871456L;
/** Map of input graph node (placeholder) names to transformers that convert model input into tensors. */
private final Map<String, InputTransformer<I>> inputs = new HashMap<>();
/** List of output graph nodes. */
private List<String> outputNames;
/** Transformer that converts fetched tensors into model output. */
private OutputTransformer<O> outputTransformer;
/** {@inheritDoc} */
@Override public Model<I, O> parse(byte[] mdl) {
return new TensorFlowInfModel(parseModel(mdl));
}
/**
* Parses a model specified in serialized form as a byte array.
*
* @param mdl Inference model in serialized form as a byte array.
* @return TensorFlow session that encapsulates the TensorFlow graph parsed from the serialized model.
*/
public abstract Session parseModel(byte[] mdl);
/**
* Specifies an additional input graph node and the corresponding transformer that converts model input into a
* tensor.
*
* @param name Name of the input graph node.
* @param transformer Transformer that converts model input into a tensor.
* @return This instance.
*/
public TensorFlowBaseModelParser<I, O> withInput(String name, InputTransformer<I> transformer) {
if (inputs.containsKey(name))
throw new IllegalArgumentException("Inputs already contain the specified name [name=" + name + "]");
inputs.put(name, transformer);
return this;
}
/**
* Specifies the output graph nodes and the corresponding transformer that converts fetched tensors into model
* output.
*
* @param names List of output graph node names.
* @param transformer Transformer that converts fetched tensors into model output.
* @return This instance.
*/
public TensorFlowBaseModelParser<I, O> withOutput(List<String> names, OutputTransformer<O> transformer) {
if (outputNames != null || outputTransformer != null)
throw new IllegalArgumentException("Outputs already specified");
outputNames = names;
outputTransformer = transformer;
return this;
}
/**
* Input transformer that accepts model input and converts it into a tensor.
*
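* <p>For example, for a {@code double[]} model input the transformer might be a one-liner over the TensorFlow 1.x
* Java API (a sketch; the actual tensor shape and type depend on the graph):
* <pre>{@code
* InputTransformer<double[]> transformer = arr -> Tensors.create(arr);
* }</pre>
*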
* @param <I> Type of model input.
*/
@FunctionalInterface
public interface InputTransformer<I> extends Serializable {
/**
* Transforms input into a tensor.
*
* @param input Input data.
* @return Tensor (transformed input data).
*/
public Tensor<?> transform(I input);
}
/**
* Output transformer that accepts fetched tensors and converts them into model output.
*
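* <p>For example, a transformer that reads a single scalar tensor fetched from a node named {@code "output"}
* (the node name is hypothetical) might look like:
* <pre>{@code
* OutputTransformer<Double> transformer = tensors -> tensors.get("output").doubleValue();
* }</pre>
*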
* @param <O> Type of model output.
*/
@FunctionalInterface
public interface OutputTransformer<O> extends Serializable {
/**
* Transforms fetched tensors into model output.
*
* @param output Map of output graph node names to fetched tensors.
* @return Output (transformed tensors).
*/
public O transform(Map<String, Tensor<?>> output);
}
/**
* TensorFlow inference model based on a pre-loaded graph and a created session.
*/
private class TensorFlowInfModel implements Model<I, O> {
/** TensorFlow session. */
private final Session ses;
/**
* Constructs a new instance of TensorFlow inference model.
*
* @param ses TensorFlow session.
*/
TensorFlowInfModel(Session ses) {
this.ses = ses;
}
/** {@inheritDoc} */
@Override public O predict(I input) {
Session.Runner runner = ses.runner();
runner = feedAll(runner, input);
runner = fetchAll(runner);
List<Tensor<?>> prediction = runner.run();
Map<String, Tensor<?>> collectedPredictionTensors = indexTensors(prediction);
return outputTransformer.transform(collectedPredictionTensors);
}
/**
* Feeds input into the graph's input nodes using the input transformers (see {@link #inputs}).
*
* @param runner TensorFlow session runner.
* @param input Input.
* @return TensorFlow session runner.
*/
private Session.Runner feedAll(Session.Runner runner, I input) {
for (Map.Entry<String, InputTransformer<I>> e : inputs.entrySet()) {
String opName = e.getKey();
InputTransformer<I> transformer = e.getValue();
runner = runner.feed(opName, transformer.transform(input));
}
return runner;
}
/**
* Specifies graph output nodes to be fetched using {@link #outputNames}.
*
* @param runner TensorFlow session runner.
* @return TensorFlow session runner.
*/
private Session.Runner fetchAll(Session.Runner runner) {
for (String e : outputNames)
runner.fetch(e);
return runner;
}
/**
* Indexes tensors fetched from the graph using {@link #outputNames}.
*
* @param tensors List of fetched tensors.
* @return Map with output node names as keys and fetched tensors as values.
*/
private Map<String, Tensor<?>> indexTensors(List<Tensor<?>> tensors) {
Map<String, Tensor<?>> collectedTensors = new HashMap<>();
Iterator<String> outputNamesIter = outputNames.iterator();
Iterator<Tensor<?>> tensorsIter = tensors.iterator();
while (outputNamesIter.hasNext() && tensorsIter.hasNext())
collectedTensors.put(outputNamesIter.next(), tensorsIter.next());
// Output names and fetched tensors are expected to have the same size.
if (outputNamesIter.hasNext() || tensorsIter.hasNext())
throw new IllegalStateException("Number of fetched tensors doesn't match the number of output names");
return collectedTensors;
}
/** {@inheritDoc} */
@Override public void close() {
ses.close();
}
}
}