blob: de499a3f7b12697bd24e4eef962817eff2f6518b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.dl.imagerec;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder;
import org.json.simple.JSONArray;
import org.json.simple.JSONObject;
import org.json.simple.parser.JSONParser;
import org.json.simple.parser.ParseException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xml.sax.ContentHandler;
import org.xml.sax.SAXException;
import org.apache.tika.config.Field;
import org.apache.tika.config.InitializableProblemHandler;
import org.apache.tika.config.Param;
import org.apache.tika.exception.TikaConfigException;
import org.apache.tika.exception.TikaException;
import org.apache.tika.metadata.Metadata;
import org.apache.tika.mime.MediaType;
import org.apache.tika.parser.ParseContext;
import org.apache.tika.parser.recognition.ObjectRecogniser;
import org.apache.tika.parser.recognition.RecognisedObject;
/**
* {@link DL4JInceptionV3Net} is an implementation of {@link ObjectRecogniser}.
* This object recogniser is powered by <a href="http://deeplearning4j.org">Deeplearning4j</a>.
* This implementation is pre configured to use <a href="https://arxiv.org/abs/1512.00567">
* Google's InceptionV3 model </a> pre trained on
* ImageNet corpus. The models references in default settings are originally trained and exported
* from <a href="http://keras.io">Keras </a> and imported using DL4J's importer tools.
* <p>
* Although this implementation is made to work out of the box without user attention,
* for advances users who are interested in tweaking the settings, the following fields are
* configurable:
* <ul>
* <li>{@link #modelWeightsPath}</li>
* <li>{@link #labelFile}</li>
* <li>{@link #labelLang}</li>
* <li>{@link #cacheDir}</li>
* <li>{@link #imgWidth}</li>
* <li>{@link #imgHeight}</li>
* <li>{@link #imgChannels}</li>
* <li>{@link #minConfidence}</li>
* </ul>
* </p>
*
* @see ObjectRecogniser
* @see org.apache.tika.parser.recognition.ObjectRecognitionParser
* @see org.apache.tika.parser.recognition.tf.TensorflowImageRecParser
* @see org.apache.tika.parser.recognition.tf.TensorflowRESTRecogniser
* @since Tika 1.15
*/
public class DL4JInceptionV3Net implements ObjectRecogniser {
private static final Set<MediaType> MEDIA_TYPES =
Collections.singleton(MediaType.image("jpeg"));
private static final Logger LOG = LoggerFactory.getLogger(DL4JInceptionV3Net.class);
private static final String DEF_WEIGHTS_URL =
"https://github.com/USCDataScience/tika-dockers/releases/download/v0.2/inception_v3_keras_2.h5";
private static final String DEF_LABEL_MAPPING_URL =
"https://github.com/USCDataScience/tika-dockers/releases/download/v0.2/imagenet_class_index.json";
private static final String BASE_DIR =
System.getProperty("user.home") + File.separator + ".tika-dl" + File.separator +
"models" + File.separator + "keras";
private static final String MODEL_DIR = BASE_DIR + File.separator + "inception-v3";
/**
* Cache dir to be used for downloading the weights file.
* This is used to download the model.
*/
@Field
private File cacheDir = new File(MODEL_DIR);
/**
* Path to a HDF5 file that contains weights of the Keras network
* that was obtained by training the network on a labelled dataset.
* <br/>
* Note: when the value is set to &lt;download&gt;, the default model will be
* downloaded from {@value #DEF_WEIGHTS_URL}
*/
@Field
private String modelWeightsPath = DEF_WEIGHTS_URL;
/***
* Path to file that tells how to map node index to human readable label names
* <br/>
* The default is retrieved from {@value DEF_LABEL_MAPPING_URL}
*/
@Field
private String labelFile = DEF_LABEL_MAPPING_URL;
/**
* Language name of the labels.
* <br/>
* Default is 'en'
*/
@Field
private String labelLang = "en";
@Field
private int imgHeight = 299;
@Field
private int imgWidth = 299;
@Field
private int imgChannels = 3;
/***
* Ignores the labels that are below this confidence score
*/
@Field
private double minConfidence = 0.005;
private ComputationGraph graph;
private NativeImageLoader imageLoader;
private Map<Integer, String> labelMap;
private static synchronized File cachedDownload(File cacheDir, URI uri) throws IOException {
if ("file".equals(uri.getScheme()) || uri.getScheme() == null) {
return new File(uri);
}
if (!cacheDir.exists()) {
cacheDir.mkdirs();
}
String[] parts = uri.toASCIIString().split("/");
File cacheFile = new File(cacheDir, parts[parts.length - 1]);
File successFlag = new File(cacheFile.getAbsolutePath() + ".success");
if (cacheFile.exists() && successFlag.exists()) {
LOG.info("Cache exist at {}. Not downloading it", cacheFile.getAbsolutePath());
} else {
if (successFlag.exists()) {
successFlag.delete();
}
LOG.info("Cache doesn't exist. Going to make a copy");
LOG.info("This might take a while! GET {}", uri);
FileUtils.copyURLToFile(uri.toURL(), cacheFile, 5000, 60000);
//restore the success flag again
FileUtils.write(successFlag, "CopiedAt:" + System.currentTimeMillis(),
Charset.defaultCharset());
}
return cacheFile;
}
@Override
public Set<MediaType> getSupportedMimes() {
return MEDIA_TYPES;
}
/***
*
* @param path path to resolve the file
* @return File or null
*/
private File retrieveFile(String path) {
File file = new File(path);
if (!file.exists()) {
LOG.warn("File {} not found in local file system." + " Asking the classloader", path);
URL url = getClass().getClassLoader().getResource(path);
if (url == null) {
LOG.debug("Classloader does not know the file {}", path);
file = null;
} else {
LOG.debug("Classloader knows the file {}", path);
try {
file = cachedDownload(cacheDir, url.toURI());
} catch (URISyntaxException | IOException e) {
LOG.warn(e.getMessage(), e);
}
}
}
return file;
}
private InputStream retrieveResource(String path) throws FileNotFoundException {
File file = new File(path);
if (file.exists()) {
return new FileInputStream(file);
}
LOG.warn("File {} not found in local file system. Asking the classloader", path);
return getClass().getClassLoader().getResourceAsStream(path);
}
private String mayBeDownloadFile(String path) throws TikaConfigException {
String resolvedFilePath;
if (path.startsWith("http://") || path.startsWith("https://")) {
LOG.debug("Config instructed to download the file, doing so.");
try {
resolvedFilePath = cachedDownload(cacheDir, URI.create(path)).getAbsolutePath();
} catch (IOException e) {
throw new TikaConfigException(e.getMessage(), e);
}
} else {
File file = retrieveFile(path);
if (!file.exists()) {
LOG.error("File does not exist at :: {}", path);
}
resolvedFilePath = file.getAbsolutePath();
}
return resolvedFilePath;
}
@Override
public void initialize(Map<String, Param> params) throws TikaConfigException {
//STEP 1: resolve weights file, download if necessary
modelWeightsPath = mayBeDownloadFile(modelWeightsPath);
//STEP 2: Load labels map
try (InputStream stream = retrieveResource(mayBeDownloadFile(labelFile))) {
this.labelMap = loadClassIndex(stream);
} catch (IOException | ParseException e) {
LOG.error("Could not load labels map", e);
return;
}
//STEP 3: initialize the graph
try {
this.imageLoader = new NativeImageLoader(imgHeight, imgWidth, imgChannels);
LOG.info("Going to load Inception network...");
long st = System.currentTimeMillis();
KerasModelBuilder builder =
new KerasModel().modelBuilder().modelHdf5Filename(modelWeightsPath)
.enforceTrainingConfig(false);
builder.inputShape(new int[]{imgHeight, imgWidth, 3});
KerasModel model = builder.buildModel();
this.graph = model.getComputationGraph();
long time = System.currentTimeMillis() - st;
LOG.info("Loaded the Inception model. Time taken={}ms", time);
} catch (IOException | InvalidKerasConfigurationException |
UnsupportedKerasConfigurationException e) {
throw new TikaConfigException(e.getMessage(), e);
}
}
@Override
public void checkInitialization(InitializableProblemHandler problemHandler)
throws TikaConfigException {
//TODO: what do we want to check here?
}
@Override
public boolean isAvailable() {
return graph != null;
}
/**
* Pre process image to reduce to make it feedable to inception network
*
* @param input Input image
* @return processed image
*/
public INDArray preProcessImage(INDArray input) {
// Transform to [-1.0, 1.0] range
return input.div(255.0).sub(0.5).mul(2.0);
}
/**
* Loads the class to
*
* @param stream label index stream
* @return Map of integer -&gt; label name
* @throws IOException when the stream breaks unexpectedly
* @throws ParseException when the input doesn't contain a valid JSON map
*/
public Map<Integer, String> loadClassIndex(InputStream stream)
throws IOException, ParseException {
String content = IOUtils.toString(stream, StandardCharsets.UTF_8);
JSONObject jIndex = (JSONObject) new JSONParser().parse(content);
Map<Integer, String> classMap = new HashMap<>();
for (Object key : jIndex.keySet()) {
JSONArray names = (JSONArray) jIndex.get(key);
classMap.put(Integer.parseInt(key.toString()), names.get(names.size() - 1).toString());
}
return classMap;
}
@Override
public List<RecognisedObject> recognise(InputStream stream, ContentHandler handler,
Metadata metadata, ParseContext context)
throws IOException, SAXException, TikaException {
INDArray image = preProcessImage(imageLoader.asMatrix(stream));
INDArray scores = graph.outputSingle(image);
List<RecognisedObject> result = new ArrayList<>();
for (int i = 0; i < scores.length(); i++) {
if (scores.getDouble(i) > minConfidence) {
String label = labelMap.get(i);
String id = i + "";
result.add(new RecognisedObject(label, labelLang, id, scores.getDouble(i)));
LOG.debug("Found Object {}", label);
}
}
return result;
}
}