blob: 52941fe641f51d4302585ad99ced14ed3795089c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.postag;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Collections;
import java.util.HashSet;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import opennlp.tools.dictionary.Dictionary;
import opennlp.tools.ml.model.AbstractModel;
import opennlp.tools.namefind.TokenNameFinderFactory;
import opennlp.tools.util.BaseToolFactory;
import opennlp.tools.util.InvalidFormatException;
import opennlp.tools.util.SequenceValidator;
import opennlp.tools.util.Version;
import opennlp.tools.util.ext.ExtensionLoader;
import opennlp.tools.util.featuregen.AdaptiveFeatureGenerator;
import opennlp.tools.util.featuregen.AggregatedFeatureGenerator;
import opennlp.tools.util.featuregen.GeneratorFactory;
import opennlp.tools.util.model.ArtifactSerializer;
import opennlp.tools.util.model.UncloseableInputStream;
/**
* The factory that provides POS Tagger default implementations and resources
*/
public class POSTaggerFactory extends BaseToolFactory {
private static final String TAG_DICTIONARY_ENTRY_NAME = "tags.tagdict";
private static final String NGRAM_DICTIONARY_ENTRY_NAME = "ngram.dictionary";
protected Dictionary ngramDictionary;
private byte[] featureGeneratorBytes;
private Map<String, Object> resources;
protected TagDictionary posDictionary;
/**
* Creates a {@link POSTaggerFactory} that provides the default implementation
* of the resources.
*/
public POSTaggerFactory() {
}
/**
* Creates a {@link POSTaggerFactory}. Use this constructor to
* programmatically create a factory.
*
* @param ngramDictionary
* @param posDictionary
*
* @deprecated this constructor is here for backward compatibility and
* is not functional anymore in the training of 1.8.x series models
*/
@Deprecated
public POSTaggerFactory(Dictionary ngramDictionary, TagDictionary posDictionary) {
this.init(ngramDictionary, posDictionary);
// TODO: This could be made functional by creating some default feature generation
// which uses the dictionary ...
}
public POSTaggerFactory(byte[] featureGeneratorBytes, final Map<String, Object> resources,
TagDictionary posDictionary) {
this.featureGeneratorBytes = featureGeneratorBytes;
if (this.featureGeneratorBytes == null) {
this.featureGeneratorBytes = loadDefaultFeatureGeneratorBytes();
}
this.resources = resources;
this.posDictionary = posDictionary;
}
@Deprecated // will be removed when only 8 series models are supported
protected void init(Dictionary ngramDictionary, TagDictionary posDictionary) {
this.ngramDictionary = ngramDictionary;
this.posDictionary = posDictionary;
}
protected void init(byte[] featureGeneratorBytes, final Map<String, Object> resources,
TagDictionary posDictionary) {
this.featureGeneratorBytes = featureGeneratorBytes;
this.resources = resources;
this.posDictionary = posDictionary;
}
private static byte[] loadDefaultFeatureGeneratorBytes() {
ByteArrayOutputStream bytes = new ByteArrayOutputStream();
try (InputStream in = TokenNameFinderFactory.class.getResourceAsStream(
"/opennlp/tools/postag/pos-default-features.xml")) {
if (in == null) {
throw new IllegalStateException("Classpath must contain pos-default-features.xml file!");
}
byte[] buf = new byte[1024];
int len;
while ((len = in.read(buf)) > 0) {
bytes.write(buf, 0, len);
}
}
catch (IOException e) {
throw new IllegalStateException("Failed reading from pos-default-features.xml file on classpath!");
}
return bytes.toByteArray();
}
/**
* Creates the {@link AdaptiveFeatureGenerator}. Usually this
* is a set of generators contained in the {@link AggregatedFeatureGenerator}.
*
* Note:
* The generators are created on every call to this method.
*
* @return the feature generator or null if there is no descriptor in the model
*/
public AdaptiveFeatureGenerator createFeatureGenerators() {
if (featureGeneratorBytes == null && artifactProvider != null) {
featureGeneratorBytes = artifactProvider.getArtifact(
POSModel.GENERATOR_DESCRIPTOR_ENTRY_NAME);
}
if (featureGeneratorBytes == null) {
featureGeneratorBytes = loadDefaultFeatureGeneratorBytes();
}
InputStream descriptorIn = new ByteArrayInputStream(featureGeneratorBytes);
AdaptiveFeatureGenerator generator;
try {
generator = GeneratorFactory.create(descriptorIn, key -> {
if (artifactProvider != null) {
return artifactProvider.getArtifact(key);
}
else {
return resources.get(key);
}
});
} catch (InvalidFormatException e) {
// It is assumed that the creation of the feature generation does not
// fail after it succeeded once during model loading.
// But it might still be possible that such an exception is thrown,
// in this case the caller should not be forced to handle the exception
// and a Runtime Exception is thrown instead.
// If the re-creation of the feature generation fails it is assumed
// that this can only be caused by a programming mistake and therefore
// throwing a Runtime Exception is reasonable
throw new IllegalStateException(); // FeatureGeneratorCreationError(e);
} catch (IOException e) {
throw new IllegalStateException("Reading from mem cannot result in an I/O error", e);
}
return generator;
}
@Override
@SuppressWarnings("rawtypes")
public Map<String, ArtifactSerializer> createArtifactSerializersMap() {
Map<String, ArtifactSerializer> serializers = super.createArtifactSerializersMap();
// NOTE: This is only needed for old models and this if can be removed if support is dropped
POSDictionarySerializer.register(serializers);
return serializers;
}
@Override
public Map<String, Object> createArtifactMap() {
Map<String, Object> artifactMap = super.createArtifactMap();
if (posDictionary != null)
artifactMap.put(TAG_DICTIONARY_ENTRY_NAME, posDictionary);
if (ngramDictionary != null)
artifactMap.put(NGRAM_DICTIONARY_ENTRY_NAME, ngramDictionary);
return artifactMap;
}
public TagDictionary createTagDictionary(File dictionary)
throws IOException {
return createTagDictionary(new FileInputStream(dictionary));
}
public TagDictionary createTagDictionary(InputStream in)
throws IOException {
return POSDictionary.create(in);
}
public void setTagDictionary(TagDictionary dictionary) {
if (artifactProvider != null) {
throw new IllegalStateException(
"Can not set tag dictionary while using artifact provider.");
}
this.posDictionary = dictionary;
}
protected Map<String, Object> getResources() {
if (resources != null) {
return resources;
}
return Collections.emptyMap();
}
protected byte[] getFeatureGenerator() {
return featureGeneratorBytes;
}
public TagDictionary getTagDictionary() {
if (this.posDictionary == null && artifactProvider != null)
this.posDictionary = artifactProvider.getArtifact(TAG_DICTIONARY_ENTRY_NAME);
return this.posDictionary;
}
/**
* @deprecated this will be reduced in visibility and later removed
*/
@Deprecated
public Dictionary getDictionary() {
if (this.ngramDictionary == null && artifactProvider != null)
this.ngramDictionary = artifactProvider.getArtifact(NGRAM_DICTIONARY_ENTRY_NAME);
return this.ngramDictionary;
}
@Deprecated
public void setDictionary(Dictionary ngramDict) {
if (artifactProvider != null) {
throw new IllegalStateException(
"Can not set ngram dictionary while using artifact provider.");
}
this.ngramDictionary = ngramDict;
}
public POSContextGenerator getPOSContextGenerator() {
return getPOSContextGenerator(0);
}
public POSContextGenerator getPOSContextGenerator(int cacheSize) {
if (artifactProvider != null) {
Properties manifest = (Properties) artifactProvider.getArtifact("manifest.properties");
String version = manifest.getProperty("OpenNLP-Version");
if (Version.parse(version).getMinor() < 8) {
return new DefaultPOSContextGenerator(cacheSize, getDictionary());
}
}
return new ConfigurablePOSContextGenerator(cacheSize, createFeatureGenerators());
}
public SequenceValidator<String> getSequenceValidator() {
return new DefaultPOSSequenceValidator(getTagDictionary());
}
// TODO: This should not be done anymore for 8 models, they can just
// use the SerializableArtifact interface
public static class POSDictionarySerializer implements ArtifactSerializer<POSDictionary> {
public POSDictionary create(InputStream in) throws IOException {
return POSDictionary.create(new UncloseableInputStream(in));
}
public void serialize(POSDictionary artifact, OutputStream out)
throws IOException {
artifact.serialize(out);
}
@SuppressWarnings("rawtypes")
static void register(Map<String, ArtifactSerializer> factories) {
factories.put("tagdict", new POSDictionarySerializer());
}
}
protected void validatePOSDictionary(POSDictionary posDict,
AbstractModel posModel) throws InvalidFormatException {
Set<String> dictTags = new HashSet<>();
for (String word : posDict) {
Collections.addAll(dictTags, posDict.getTags(word));
}
Set<String> modelTags = new HashSet<>();
for (int i = 0; i < posModel.getNumOutcomes(); i++) {
modelTags.add(posModel.getOutcome(i));
}
if (!modelTags.containsAll(dictTags)) {
StringBuilder unknownTag = new StringBuilder();
for (String d : dictTags) {
if (!modelTags.contains(d)) {
unknownTag.append(d).append(" ");
}
}
throw new InvalidFormatException("Tag dictionary contains tags "
+ "which are unknown by the model! The unknown tags are: "
+ unknownTag);
}
}
@Override
public void validateArtifactMap() throws InvalidFormatException {
// Ensure that the tag dictionary is compatible with the model
Object tagdictEntry = this.artifactProvider
.getArtifact(TAG_DICTIONARY_ENTRY_NAME);
if (tagdictEntry != null) {
if (tagdictEntry instanceof POSDictionary) {
if (!this.artifactProvider.isLoadedFromSerialized()) {
AbstractModel posModel = this.artifactProvider
.getArtifact(POSModel.POS_MODEL_ENTRY_NAME);
POSDictionary posDict = (POSDictionary) tagdictEntry;
validatePOSDictionary(posDict, posModel);
}
} else {
throw new InvalidFormatException(
"POSTag dictionary has wrong type!");
}
}
Object ngramDictEntry = this.artifactProvider
.getArtifact(NGRAM_DICTIONARY_ENTRY_NAME);
if (ngramDictEntry != null && !(ngramDictEntry instanceof Dictionary)) {
throw new InvalidFormatException("NGram dictionary has wrong type!");
}
}
@Deprecated
public static POSTaggerFactory create(String subclassName,
Dictionary ngramDictionary, TagDictionary posDictionary)
throws InvalidFormatException {
if (subclassName == null) {
// will create the default factory
return new POSTaggerFactory(ngramDictionary, posDictionary);
}
try {
POSTaggerFactory theFactory = ExtensionLoader.instantiateExtension(
POSTaggerFactory.class, subclassName);
theFactory.init(ngramDictionary, posDictionary);
return theFactory;
} catch (Exception e) {
String msg = "Could not instantiate the " + subclassName
+ ". The initialization throw an exception.";
throw new InvalidFormatException(msg, e);
}
}
public static POSTaggerFactory create(String subclassName, byte[] featureGeneratorBytes,
Map<String, Object> resources, TagDictionary posDictionary)
throws InvalidFormatException {
POSTaggerFactory theFactory;
if (subclassName == null) {
// will create the default factory
theFactory = new POSTaggerFactory(null, posDictionary);
}
else {
try {
theFactory = ExtensionLoader.instantiateExtension(
POSTaggerFactory.class, subclassName);
} catch (Exception e) {
String msg = "Could not instantiate the " + subclassName
+ ". The initialization throw an exception.";
throw new InvalidFormatException(msg, e);
}
}
theFactory.init(featureGeneratorBytes, resources, posDictionary);
return theFactory;
}
public TagDictionary createEmptyTagDictionary() {
this.posDictionary = new POSDictionary(true);
return this.posDictionary;
}
}