blob: 26c2abd868eddb74cc4d9ee061e28fb42aaaf7bb [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.opennlp.utils.cfg;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * A probabilistic context free grammar (PCFG).
 * <p>
 * A grammar is made of a set of non-terminal symbols, a set of terminal symbols, a map
 * from {@link Rule}s to their probabilities and a start symbol. It can be used to
 * generate derivations ({@link #leftMostDerivation(String...)}) and to parse sentences
 * ({@link #cky(java.util.List)}). Grammars can be built from bracketed parse tree
 * strings via the {@code parseGrammar} / {@code parseRules} factory methods.
 * <p>
 * NOTE(review): {@link Rule} is used as a map key and looked up through freshly created
 * instances, so it is assumed to implement value based {@code equals}/{@code hashCode}.
 */
public class ProbabilisticContextFreeGrammar {

  private final Collection<String> nonTerminalSymbols;
  private final Collection<String> terminalSymbols;
  private final Map<Rule, Double> rules;
  private final String startSymbol;
  private final boolean randomExpansion;

  /** Shared generator used to pick among alternative rules when random expansion is enabled. */
  private final Random random = new Random();

  /** Rule {@code E -> ""} used as right hand filler when binarizing unary {@code nt1 -> nt2} rules. */
  private static final Rule emptyRule = new Rule("E", "");

  private static final String nonTerminalMatcher = "[\\w\\~\\*\\-\\.\\,\\'\\:\\_\\\"]";
  private static final String terminalMatcher = "[òàùìèé\\|\\w\\'\\.\\,\\:\\_Ù\\?È\\%\\;À\\-\\\"]";

  /** Matches a pre-terminal node, e.g. {@code (NN dog)}. */
  private static final Pattern terminalPattern = Pattern.compile("\\(("+nonTerminalMatcher+"+)\\s("+terminalMatcher+"+)\\)");

  /** Matches a node whose children have all already been reduced to symbols, e.g. {@code (NP DT NN)}. */
  private static final Pattern nonTerminalPattern = Pattern.compile(
      "\\(("+nonTerminalMatcher+"+)" + // source NT
          "\\s("+nonTerminalMatcher+"+)(\\s("+nonTerminalMatcher+"+))*\\)" // expansion NTs
  );

  /**
   * Creates a grammar.
   *
   * @param nonTerminalSymbols the non-terminal symbols; expected to contain {@code startSymbol}
   *                           (only checked when assertions are enabled)
   * @param terminalSymbols    the terminal symbols
   * @param rules              the rules together with their probabilities
   * @param startSymbol        the start symbol
   * @param randomExpansion    whether derivations pick randomly among the rules available for a
   *                           symbol ({@code true}) or always use the first one found ({@code false})
   */
  public ProbabilisticContextFreeGrammar(Collection<String> nonTerminalSymbols, Collection<String> terminalSymbols,
                                         Map<Rule, Double> rules, String startSymbol, boolean randomExpansion) {
    assert nonTerminalSymbols.contains(startSymbol) : "start symbol doesn't belong to non-terminal symbols set";
    this.nonTerminalSymbols = nonTerminalSymbols;
    this.terminalSymbols = terminalSymbols;
    this.rules = rules;
    this.startSymbol = startSymbol;
    this.randomExpansion = randomExpansion;
  }

  /**
   * Creates a grammar with random expansion disabled.
   *
   * @param nonTerminalSymbols the non-terminal symbols
   * @param terminalSymbols    the terminal symbols
   * @param rules              the rules together with their probabilities
   * @param startSymbol        the start symbol
   */
  public ProbabilisticContextFreeGrammar(Collection<String> nonTerminalSymbols, Collection<String> terminalSymbols, Map<Rule, Double> rules, String startSymbol) {
    this(nonTerminalSymbols, terminalSymbols, rules, startSymbol, false);
  }

  /** @return the non-terminal symbols backing this grammar (not a copy) */
  public Collection<String> getNonTerminalSymbols() {
    return nonTerminalSymbols;
  }

  /** @return the terminal symbols backing this grammar (not a copy) */
  public Collection<String> getTerminalSymbols() {
    return terminalSymbols;
  }

  /** @return the rule to probability map backing this grammar (not a copy) */
  public Map<Rule, Double> getRules() {
    return rules;
  }

  /** @return the start symbol */
  public String getStartSymbol() {
    return startSymbol;
  }

  /**
   * Expands the given symbols, left to right, until only terminals are left.
   *
   * @param words the symbols to expand; the first one is expected to be the start symbol
   *              (only checked when assertions are enabled)
   * @return the resulting sequence of terminals
   * @throws RuntimeException if a symbol that is not a terminal has no rule expanding it
   */
  public String[] leftMostDerivation(String... words) {
    List<String> expansion = new ArrayList<String>(words.length);
    assert words.length > 0 && startSymbol.equals(words[0]);
    for (String word : words) {
      expansion.addAll(getTerminals(word));
    }
    return expansion.toArray(new String[expansion.size()]);
  }

  /** Recursively expands {@code word} into the terminals it derives. */
  private Collection<String> getTerminals(String word) {
    Collection<String> c = new LinkedList<String>();
    if (terminalSymbols.contains(word)) {
      c.add(word);
    } else {
      assert nonTerminalSymbols.contains(word) : "word " + word + " is not contained in non terminals";
      for (String e : getExpansionForSymbol(word)) {
        c.addAll(getTerminals(e));
      }
    }
    return c;
  }

  /** Returns the right hand side of the rule chosen for expanding {@code currentSymbol}. */
  private String[] getExpansionForSymbol(String currentSymbol) {
    return getRuleForSymbol(currentSymbol).getExpansion();
  }

  /**
   * Picks a rule whose left hand side is {@code word}: the first one found when random
   * expansion is disabled, a uniformly random one otherwise.
   */
  private Rule getRuleForSymbol(String word) {
    List<Rule> possibleRules = new ArrayList<Rule>();
    for (Rule r : rules.keySet()) {
      if (word.equals(r.getEntry())) {
        if (!randomExpansion) {
          return r;
        }
        possibleRules.add(r);
      }
    }
    if (possibleRules.isEmpty()) {
      throw new RuntimeException("could not find a rule for expanding symbol " + word);
    }
    return possibleRules.get(random.nextInt(possibleRules.size()));
  }

  /**
   * Computes the most probable parse tree rooted at symbol {@code x} spanning the words
   * {@code i..j} (both inclusive) of the sentence.
   * <p>
   * For a single word the tree is the rule {@code x -> word} with its probability (0 if the
   * grammar has no such rule). For longer spans every binary rule {@code x -> Y Z} and every
   * split point {@code s} in {@code [i, j)} is tried, recursing on {@code Y} over {@code i..s}
   * and on {@code Z} over {@code s+1..j}. No intermediate results are cached.
   *
   * @param sentence the words of the sentence
   * @param i        index of the first word of the span
   * @param j        index of the last word of the span
   * @param x        the symbol the tree has to be rooted at
   * @return the best tree found; a tree with probability 0 and a {@code null} rule if the span
   *         cannot be derived from {@code x}
   */
  public ParseTree pi(List<String> sentence, int i, int j, String x) {
    ParseTree parseTree = new ParseTree(0, 0, null);
    if (i == j) {
      Rule rule = new Rule(x, sentence.get(i));
      parseTree = new ParseTree(q(rule), i, rule);
    } else {
      double max = 0;
      // only rules expanding x may root the tree; considering every binary rule would
      // return subtrees that are not rooted at the requested symbol
      for (Rule rule : getRulesForNonTerminal(x)) {
        double q = q(rule);
        String[] expansion = rule.getExpansion();
        for (int s = i; s < j; s++) {
          ParseTree left = pi(sentence, i, s, expansion[0]);
          ParseTree right = pi(sentence, s + 1, j, expansion[1]);
          double cp = q * left.getProbability() * right.getProbability();
          if (cp > max) {
            max = cp;
            parseTree = new ParseTree(max, s, rule, left, right);
          }
        }
      }
    }
    return parseTree;
  }

  /**
   * Parses the sentence, trying spans of increasing length and, for each span, every
   * non-terminal as root.
   * <p>
   * The returned tree is the best one found for the last span (in iteration order) that could
   * be parsed with a probability greater than 0; this is the whole sentence whenever the
   * grammar derives it.
   *
   * @param sentence the words of the sentence
   * @return the parse tree, or {@code null} if no span of two or more words could be parsed
   *         (which is always the case for sentences shorter than two words)
   */
  public ParseTree cky(List<String> sentence) {
    ParseTree parseTree = null;
    int n = sentence.size();
    for (int l = 1; l < n; l++) {
      for (int i = 0; i < n - l; i++) {
        int j = i + l;
        double max = 0;
        for (String x : getNonTerminalSymbols()) {
          for (Rule r : getRulesForNonTerminal(x)) {
            double q = q(r);
            String[] expansion = r.getExpansion();
            // split points range over [i, j): the right part s+1..j must not be empty
            for (int s = i; s < j; s++) {
              ParseTree left = pi(sentence, i, s, expansion[0]);
              ParseTree right = pi(sentence, s + 1, j, expansion[1]);
              double cp = q * left.getProbability() * right.getProbability();
              if (cp > max) {
                max = cp;
                parseTree = new ParseTree(max, s, r, left, right);
              }
            }
          }
        }
      }
    }
    return parseTree;
  }

  /** Returns the binary rules {@code x -> Y Z} where both {@code Y} and {@code Z} are non-terminals. */
  private Collection<Rule> getRulesForNonTerminal(String x) {
    Collection<Rule> ntRules = new LinkedList<Rule>();
    for (Rule r : getNTRules()) {
      if (x != null && x.equals(r.getEntry())) {
        ntRules.add(r);
      }
    }
    return ntRules;
  }

  /** Returns all the binary rules whose right hand side is made of two non-terminals. */
  private Collection<Rule> getNTRules() {
    Collection<Rule> ntRules = new LinkedList<Rule>();
    for (Rule r : rules.keySet()) {
      String[] expansion = r.getExpansion();
      if (expansion.length == 2 && nonTerminalSymbols.contains(expansion[0]) && nonTerminalSymbols.contains(expansion[1])) {
        ntRules.add(r);
      }
    }
    return ntRules;
  }

  /** Returns the probability of the given rule, 0 if the grammar doesn't contain it. */
  private double q(Rule rule) {
    Double probability = rules.get(rule);
    return probability != null ? probability : 0d;
  }

  /**
   * A (sub)tree produced by parsing: the rule applied at its root, the probability of the
   * whole subtree, the split point between the left and right children and, for binary
   * rules, the children themselves.
   */
  public class ParseTree {

    private final double probability;
    private final int splitPoint;
    private final Rule rule;
    private final ParseTree leftTree;
    private final ParseTree rightTree;

    /** Creates a tree without children. */
    private ParseTree(double probability, int splitPoint, Rule rule) {
      this(probability, splitPoint, rule, null, null);
    }

    /**
     * Creates a tree.
     *
     * @param probability the probability of the tree
     * @param splitPoint  index of the last word covered by the left child
     * @param rule        the rule applied at the root
     * @param leftTree    the left child, may be {@code null}
     * @param rightTree   the right child, may be {@code null}
     */
    public ParseTree(double probability, int splitPoint, Rule rule, ParseTree leftTree, ParseTree rightTree) {
      this.probability = probability;
      this.splitPoint = splitPoint;
      this.rule = rule;
      this.leftTree = leftTree;
      this.rightTree = rightTree;
    }

    /** @return the probability of this tree */
    public double getProbability() {
      return probability;
    }

    /** @return the split point of this tree */
    public int getSplitPoint() {
      return splitPoint;
    }

    /** @return the rule applied at the root, {@code null} for the "no parse found" tree */
    public Rule getRule() {
      return rule;
    }

    /** @return the left child, {@code null} if there is none */
    public ParseTree getLeftTree() {
      return leftTree;
    }

    /** @return the right child, {@code null} if there is none */
    public ParseTree getRightTree() {
      return rightTree;
    }

    /**
     * Renders the tree in bracketed form, e.g. {@code (S (NP a) (VP b))}.
     *
     * @return the bracketed tree; the empty string for the empty rule and for trees
     *         having no rule at all (no parse found)
     */
    @Override
    public String toString() {
      // the "no parse" tree created by pi() carries a null rule
      if (rule == null || emptyRule.equals(rule)) {
        return "";
      }
      String content = leftTree != null && rightTree != null
          ? leftTree.toString() + " " + rightTree.toString()
          : rule.getExpansion()[0];
      return "(" + rule.getEntry() + " " + content + ')';
    }
  }

  /**
   * Extracts the (normalized) rules from the given bracketed parse trees, without
   * whitespace normalization and using {@code S} as start symbol.
   *
   * @param parseTreeString the bracketed parse trees
   * @return the extracted rules, each one with probability 1
   */
  public static Map<Rule, Double> parseRules(String... parseTreeString) {
    Map<Rule, Double> rules = new HashMap<>();
    parseRules(rules, false, parseTreeString);
    return rules;
  }

  /**
   * Extracts the (normalized) rules from the given bracketed parse trees into {@code rules},
   * using {@code S} as start symbol.
   *
   * @param rules        the map the extracted rules are added to
   * @param trim         whether to normalize the whitespace of each tree before parsing it
   * @param parseStrings the bracketed parse trees
   */
  public static void parseRules(Map<Rule, Double> rules, boolean trim, String... parseStrings) {
    parseGrammar(rules, "S", trim, parseStrings);
  }

  /**
   * Builds a grammar with start symbol {@code S} from the given bracketed parse trees.
   *
   * @param trim             whether to normalize the whitespace of each tree before parsing it
   * @param parseTreeStrings the bracketed parse trees
   * @return the grammar
   */
  public static ProbabilisticContextFreeGrammar parseGrammar(boolean trim, String... parseTreeStrings) {
    return parseGrammar(new HashMap<Rule, Double>(), "S", trim, parseTreeStrings);
  }

  /**
   * Builds a grammar with start symbol {@code S} from the given bracketed parse trees,
   * normalizing their whitespace first.
   *
   * @param parseTreeStrings the bracketed parse trees
   * @return the grammar
   */
  public static ProbabilisticContextFreeGrammar parseGrammar(String... parseTreeStrings) {
    return parseGrammar(new HashMap<Rule, Double>(), "S", true, parseTreeStrings);
  }

  /**
   * Builds a grammar from bracketed parse trees such as {@code (S (NP a) (VP b))}.
   * <p>
   * Each tree is reduced bottom up: pre-terminal nodes ({@code (NN dog)}) are collected as
   * terminal rules and replaced by their label, then nodes whose children are all labels
   * ({@code (NP DT NN)}) are repeatedly collected and replaced by their own label. The
   * collected rules are finally normalized into {@code rulesMap}, each with probability 1,
   * and the resulting grammar has random expansion enabled.
   * <p>
   * Reduction of a tree stops as soon as a pass finds nothing left to reduce, so text that
   * doesn't match the supported syntax is ignored.
   *
   * @param rulesMap     the map the normalized rules are put into; it backs the returned grammar
   * @param startSymbol  the start symbol of the grammar
   * @param trim         whether to normalize the whitespace of each tree before parsing it
   * @param parseStrings the bracketed parse trees
   * @return the grammar
   * @throws RuntimeException if a unary rule expands to a symbol that is neither a terminal
   *                          nor a non-terminal
   */
  public static ProbabilisticContextFreeGrammar parseGrammar(Map<Rule, Double> rulesMap, String startSymbol, boolean trim, String... parseStrings) {
    Map<Rule, Double> rules = new HashMap<>();
    Collection<String> nonTerminals = new HashSet<>();
    Collection<String> terminals = new HashSet<>();

    for (String parseTreeString : parseStrings) {
      if (trim) {
        parseTreeString = parseTreeString.replaceAll("\n", "").replaceAll("\t", "").replaceAll("\\s+", " ");
      }
      String toConsume = collectTerminalRules(parseTreeString, rules, terminals);

      while (toConsume.contains(" ") && !toConsume.trim().equals("( " + startSymbol + " )")) {
        String reduced = collectNonTerminalRules(toConsume, rules, nonTerminals);
        if (reduced == null) {
          // nothing matched: further passes would see the very same text and loop forever
          break;
        }
        toConsume = reduced;
      }
    }

    // TODO : check/adjust rules to make them respect CNF
    // TODO : adjust probabilities based on term frequencies
    for (Rule rule : rules.keySet()) {
      normalize(rule, nonTerminals, terminals, rulesMap);
    }
    return new ProbabilisticContextFreeGrammar(nonTerminals, terminals, rulesMap, startSymbol, true);
  }

  /**
   * Collects the pre-terminal rules of a tree and returns the tree with every pre-terminal
   * node replaced by its label.
   */
  private static String collectTerminalRules(String parseTreeString, Map<Rule, Double> rules, Collection<String> terminals) {
    String toConsume = parseTreeString;
    Matcher m = terminalPattern.matcher(parseTreeString);
    while (m.find()) {
      String nt = m.group(1);
      String t = m.group(2);
      Rule key = new Rule(nt, t);
      if (!rules.containsKey(key)) {
        rules.put(key, 1d);
        terminals.add(t);
      }
      toConsume = toConsume.replace(m.group(), nt);
    }
    return toConsume;
  }

  /**
   * Performs one reduction pass: collects the rules of the nodes whose children are all labels
   * and returns the text with those nodes replaced by their label, or {@code null} if no such
   * node was found.
   */
  private static String collectNonTerminalRules(String toConsume, Map<Rule, Double> rules, Collection<String> nonTerminals) {
    String result = null;
    Matcher m2 = nonTerminalPattern.matcher(toConsume);
    while (m2.find()) {
      if (result == null) {
        result = toConsume;
      }
      String matched = m2.group();
      // the symbols are read from the whole match: the repeated capturing group of the
      // pattern only retains its last repetition, which would drop the middle children
      String[] symbols = matched.substring(1, matched.length() - 1).split("\\s");
      String nt = symbols[0];
      String[] expansion = Arrays.copyOfRange(symbols, 1, symbols.length);
      Rule key = new Rule(nt, expansion);
      nonTerminals.addAll(Arrays.asList(expansion));
      nonTerminals.add(key.getEntry());
      if (!rules.containsKey(key)) {
        rules.put(key, 1d);
      }
      result = result.replace(matched, nt);
    }
    return result;
  }

  /**
   * Puts {@code rule} into {@code rulesMap}, rewriting it when needed so that right hand
   * sides are either a single terminal or two symbols:
   * <ul>
   * <li>{@code nt -> t} and rules with two (or no) symbols are kept as they are;</li>
   * <li>{@code nt1 -> nt2} becomes {@code nt1 -> nt2 E}, and the empty rule is registered;</li>
   * <li>{@code nt1 -> nt2 nt3 ... ntn} becomes {@code nt1 -> nt2 GEN~k} plus the (recursively
   * normalized) {@code GEN~k -> nt3 ... ntn}, {@code GEN~k} being a newly generated
   * non-terminal added to {@code nonTerminals}.</li>
   * </ul>
   *
   * @throws RuntimeException if a unary rule expands to a symbol that is neither a terminal
   *                          nor a non-terminal
   */
  private static void normalize(Rule rule, Collection<String> nonTerminals, Collection<String> terminals, Map<Rule, Double> rulesMap) {
    String[] expansion = rule.getExpansion();
    if (expansion.length == 1) {
      if (terminals.contains(expansion[0])) {
        rulesMap.put(rule, 1d);
      } else if (nonTerminals.contains(expansion[0])) {
        // nt1 -> nt2 should be expanded in nt1 -> nt2,E
        rulesMap.put(new Rule(rule.getEntry(), expansion[0], emptyRule.getEntry()), 1d);
        // E needs a production of its own, register it once
        if (!rulesMap.containsKey(emptyRule)) {
          rulesMap.put(emptyRule, 1d);
        }
      } else {
        throw new RuntimeException("rule "+rule+" expands to neither a terminal or non terminal");
      }
    } else if (expansion.length > 2) {
      // nt1 -> nt2,nt3,...,ntn should be collapsed to a hierarchy of ntX -> ntY,ntZ rules
      String nt2 = expansion[0];
      int seed = nonTerminals.size();
      String generatedNT = "GEN~" + seed;
      nonTerminals.add(generatedNT);
      rulesMap.put(new Rule(rule.getEntry(), nt2, generatedNT), 1d);
      // the upper bound of copyOfRange is exclusive: expansion.length keeps the last symbol
      Rule chainedRule = new Rule(generatedNT, Arrays.copyOfRange(expansion, 1, expansion.length));
      // the chained rule is stored by the recursive call, once it has been reduced itself
      normalize(chainedRule, nonTerminals, terminals, rulesMap);
    } else {
      rulesMap.put(rule, 1d);
    }
  }
}