blob: bb84425bbeecb2f2c9328f7ba9f94d1a7ee7eabb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.opennlp.utils.cfg;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
/**
* a probabilistic CFG
*/
public class ProbabilisticContextFreeGrammar {
private final Collection<String> nonTerminalSymbols;
private final Collection<String> terminalSymbols;
private final Map<Rule, Double> rules;
private final String startSymbol;
private boolean randomExpansion;
public ProbabilisticContextFreeGrammar(Collection<String> nonTerminalSymbols, Collection<String> terminalSymbols,
Map<Rule, Double> rules, String startSymbol, boolean randomExpansion) {
assert nonTerminalSymbols.contains(startSymbol) : "start symbol doesn't belong to non-terminal symbols set";
this.nonTerminalSymbols = nonTerminalSymbols;
this.terminalSymbols = terminalSymbols;
this.rules = rules;
this.startSymbol = startSymbol;
this.randomExpansion = randomExpansion;
}
public ProbabilisticContextFreeGrammar(Collection<String> nonTerminalSymbols, Collection<String> terminalSymbols, Map<Rule, Double> rules, String startSymbol) {
this(nonTerminalSymbols, terminalSymbols, rules, startSymbol, false);
}
public Collection<String> getNonTerminalSymbols() {
return nonTerminalSymbols;
}
public Collection<String> getTerminalSymbols() {
return terminalSymbols;
}
public Map<Rule, Double> getRules() {
return rules;
}
public String getStartSymbol() {
return startSymbol;
}
public String[] leftMostDerivation(String... words) {
ArrayList<String> expansion = new ArrayList<String>(words.length);
assert words.length > 0 && startSymbol.equals(words[0]);
for (String word : words) {
expansion.addAll(getTerminals(word));
}
return expansion.toArray(new String[expansion.size()]);
}
private Collection<String> getTerminals(String word) {
if (terminalSymbols.contains(word)) {
Collection<String> c = new LinkedList<String>();
c.add(word);
return c;
} else {
assert nonTerminalSymbols.contains(word) : "word " + word + " is not contained in non terminals";
String[] expansions = getExpansionForSymbol(word);
Collection<String> c = new LinkedList<String>();
for (String e : expansions) {
c.addAll(getTerminals(e));
}
return c;
}
}
private String[] getExpansionForSymbol(String currentSymbol) {
Rule r = getRuleForSymbol(currentSymbol);
return r.getExpansion();
}
private Rule getRuleForSymbol(String word) {
ArrayList<Rule> possibleRules = new ArrayList<Rule>();
for (Rule r : rules.keySet()) {
if (word.equals(r.getEntry())) {
if (!randomExpansion) {
return r;
}
possibleRules.add(r);
}
}
if (possibleRules.size() > 0) {
return possibleRules.get(new Random().nextInt(possibleRules.size()));
} else {
throw new RuntimeException("could not find a rule for expanding symbol " + word);
}
}
public BackPointer pi(List<String> sentence, int i, int j, String x) {
BackPointer backPointer = new BackPointer(0, 0, null);
if (i == j) {
Rule rule = new Rule(x, sentence.get(i));
double q = q(rule);
backPointer = new BackPointer(q, i, rule);
} else {
double max = 0;
for (Rule rule : getNTRules()) {
for (int s = i; s < j; s++) {
double q = q(rule);
BackPointer left = pi(sentence, i, s, rule.getExpansion()[0]);
BackPointer right = pi(sentence, s + 1, j, rule.getExpansion()[1]);
double cp = q * left.getProbability() * right.getProbability();
if (cp > max) {
max = cp;
backPointer = new BackPointer(max, s, rule, left, right);
}
}
}
}
return backPointer;
}
public BackPointer cky(List<String> sentence) {
BackPointer backPointer = null;
int n = sentence.size();
for (int l = 1; l < n; l++) {
for (int i = 0; i < n - l; i++) {
int j = i + l;
double max = 0;
for (String x : getNonTerminalSymbols()) {
for (Rule r : getRulesForNonTerminal(x)) {
for (int s = i; s < j - 1; s++) {
double q = q(r);
BackPointer left = pi(sentence, i, s, r.getExpansion()[0]);
BackPointer right = pi(sentence, s + 1, j, r.getExpansion()[1]);
double cp = q * left.getProbability() * right.getProbability();
if (cp > max) {
max = cp;
backPointer = new BackPointer(max, s, r, left, right);
}
}
}
}
}
}
return backPointer;
}
private Collection<Rule> getRulesForNonTerminal(String x) {
LinkedList<Rule> ntRules = new LinkedList<Rule>();
for (Rule r : rules.keySet()) {
if (x.equals(r.getEntry()) && nonTerminalSymbols.contains(r.getExpansion()[0]) && nonTerminalSymbols.contains(r.getExpansion()[1])) {
ntRules.add(r);
}
}
return ntRules;
}
private Collection<Rule> getNTRules() {
Collection<Rule> ntRules = new LinkedList<Rule>();
for (Rule r : rules.keySet()) {
if (nonTerminalSymbols.contains(r.getExpansion()[0]) && nonTerminalSymbols.contains(r.getExpansion()[1])) {
ntRules.add(r);
}
}
return ntRules;
}
private double q(Rule rule) {
return rules.keySet().contains(rule) ? rules.get(rule) : 0;
}
public class BackPointer {
private final double probability;
private final int splitPoint;
private final Rule rule;
private BackPointer leftTree;
private BackPointer rightTree;
private BackPointer(double probability, int splitPoint, Rule rule) {
this.probability = probability;
this.splitPoint = splitPoint;
this.rule = rule;
}
public BackPointer(double probability, int splitPoint, Rule rule, BackPointer leftTree, BackPointer rightTree) {
this.probability = probability;
this.splitPoint = splitPoint;
this.rule = rule;
this.leftTree = leftTree;
this.rightTree = rightTree;
}
public double getProbability() {
return probability;
}
public int getSplitPoint() {
return splitPoint;
}
public Rule getRule() {
return rule;
}
public BackPointer getLeftTree() {
return leftTree;
}
public BackPointer getRightTree() {
return rightTree;
}
@Override
public String toString() {
return "(" +
rule.getEntry() + " " +
(leftTree != null && rightTree != null ?
leftTree.toString() + " " + rightTree.toString() :
rule.getExpansion()[0]
) +
')';
}
}
}