OPENNLP-723 - pcfg support in sandbox (nlp-utils)
diff --git a/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/ContextFreeGrammar.java b/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/ContextFreeGrammar.java
index 9687d36..f3ae0d0 100644
--- a/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/ContextFreeGrammar.java
+++ b/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/ContextFreeGrammar.java
@@ -35,6 +35,22 @@
private final String startSymbol;
private final boolean randomExpansion;
+ public Collection<String> getNonTerminalSymbols() {
+ return nonTerminalSymbols;
+ }
+
+ public Collection<String> getTerminalSymbols() {
+ return terminalSymbols;
+ }
+
+ public Collection<Rule> getRules() {
+ return rules;
+ }
+
+ public String getStartSymbol() {
+ return startSymbol;
+ }
+
public ContextFreeGrammar(Collection<String> nonTerminalSymbols, Collection<String> terminalSymbols, Collection<Rule> rules, String startSymbol, boolean randomExpansion) {
assert nonTerminalSymbols.contains(startSymbol) : "start symbol doesn't belong to non-terminal symbols set";
@@ -62,7 +78,6 @@
}
private Collection<String> getTerminals(String word) {
-
if (terminalSymbols.contains(word)) {
Collection<String> c = new LinkedList<String>();
c.add(word);
@@ -87,7 +102,7 @@
ArrayList<Rule> possibleRules = new ArrayList<Rule>();
for (Rule r : rules) {
if (word.equals(r.getEntry())) {
- if (randomExpansion) {
+ if (!randomExpansion) {
return r;
}
possibleRules.add(r);
diff --git a/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/ProbabilisticContextFreeGrammar.java b/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/ProbabilisticContextFreeGrammar.java
new file mode 100644
index 0000000..12f58c9
--- /dev/null
+++ b/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/ProbabilisticContextFreeGrammar.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.opennlp.utils.cfg;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * a probabilistic CFG
+ */
+public class ProbabilisticContextFreeGrammar {
+
+ private final Collection<String> nonTerminalSymbols;
+ private final Collection<String> terminalSymbols;
+ private final Map<Rule, Double> rules;
+ private final String startSymbol;
+ private boolean randomExpansion;
+
+ public ProbabilisticContextFreeGrammar(Collection<String> nonTerminalSymbols, Collection<String> terminalSymbols,
+ Map<Rule, Double> rules, String startSymbol, boolean randomExpansion) {
+
+ assert nonTerminalSymbols.contains(startSymbol) : "start symbol doesn't belong to non-terminal symbols set";
+
+ this.nonTerminalSymbols = nonTerminalSymbols;
+ this.terminalSymbols = terminalSymbols;
+ this.rules = rules;
+ this.startSymbol = startSymbol;
+ this.randomExpansion = randomExpansion;
+ }
+
+ public ProbabilisticContextFreeGrammar(Collection<String> nonTerminalSymbols, Collection<String> terminalSymbols, Map<Rule, Double> rules, String startSymbol) {
+ this(nonTerminalSymbols, terminalSymbols, rules, startSymbol, false);
+ }
+
+ public Collection<String> getNonTerminalSymbols() {
+ return nonTerminalSymbols;
+ }
+
+ public Collection<String> getTerminalSymbols() {
+ return terminalSymbols;
+ }
+
+ public Map<Rule, Double> getRules() {
+ return rules;
+ }
+
+ public String getStartSymbol() {
+ return startSymbol;
+ }
+
+
+ public String[] leftMostDerivation(String... words) {
+ ArrayList<String> expansion = new ArrayList<String>(words.length);
+
+ assert words.length > 0 && startSymbol.equals(words[0]);
+
+ for (String word : words) {
+ expansion.addAll(getTerminals(word));
+ }
+ return expansion.toArray(new String[expansion.size()]);
+
+ }
+
+ private Collection<String> getTerminals(String word) {
+ if (terminalSymbols.contains(word)) {
+ Collection<String> c = new LinkedList<String>();
+ c.add(word);
+ return c;
+ } else {
+ assert nonTerminalSymbols.contains(word) : "word " + word + " is not contained in non terminals";
+ String[] expansions = getExpansionForSymbol(word);
+ Collection<String> c = new LinkedList<String>();
+ for (String e : expansions) {
+ c.addAll(getTerminals(e));
+ }
+ return c;
+ }
+ }
+
+ private String[] getExpansionForSymbol(String currentSymbol) {
+ Rule r = getRuleForSymbol(currentSymbol);
+ return r.getExpansion();
+ }
+
+ private Rule getRuleForSymbol(String word) {
+ ArrayList<Rule> possibleRules = new ArrayList<Rule>();
+ for (Rule r : rules.keySet()) {
+ if (word.equals(r.getEntry())) {
+ if (!randomExpansion) {
+ return r;
+ }
+ possibleRules.add(r);
+ }
+ }
+ if (possibleRules.size() > 0) {
+ return possibleRules.get(new Random().nextInt(possibleRules.size()));
+ } else {
+ throw new RuntimeException("could not find a rule for expanding symbol " + word);
+ }
+ }
+
+ public BackPointer pi(List<String> sentence, int i, int j, String x) {
+ BackPointer backPointer = new BackPointer(0, 0, null);
+ if (i == j) {
+ Rule rule = new Rule(x, sentence.get(i));
+ double q = q(rule);
+ backPointer = new BackPointer(q, i, rule);
+ } else {
+ double max = 0;
+ for (Rule rule : getNTRules()) {
+ for (int s = i; s < j; s++) {
+ double q = q(rule);
+ BackPointer left = pi(sentence, i, s, rule.getExpansion()[0]);
+ BackPointer right = pi(sentence, s + 1, j, rule.getExpansion()[1]);
+ double cp = q * left.getProbability() * right.getProbability();
+ if (cp > max) {
+ max = cp;
+ backPointer = new BackPointer(max, s, rule, left, right);
+ }
+ }
+ }
+ }
+ return backPointer;
+ }
+
+ public BackPointer cky(List<String> sentence, ProbabilisticContextFreeGrammar pcfg) {
+ BackPointer backPointer = null;
+
+ int n = sentence.size();
+ for (int l = 1; l < n; l++) {
+ for (int i = 0; i < n - l; i++) {
+ int j = i + l;
+ double max = 0;
+ for (String x : pcfg.getNonTerminalSymbols()) {
+ for (Rule r : getRulesForNonTerminal(x)) {
+ for (int s = i; s < j - 1; s++) {
+ double q = q(r);
+ BackPointer left = pi(sentence, i, s, r.getExpansion()[0]);
+ BackPointer right = pi(sentence, s + 1, j, r.getExpansion()[1]);
+ double cp = q * left.getProbability() * right.getProbability();
+ if (cp > max) {
+ max = cp;
+ backPointer = new BackPointer(max, s, r, left, right);
+ }
+ }
+ }
+ }
+ }
+ }
+ return backPointer;
+ }
+
+ private Collection<Rule> getRulesForNonTerminal(String x) {
+ LinkedList<Rule> ntRules = new LinkedList<Rule>();
+ for (Rule r : rules.keySet()) {
+ if (x.equals(r.getEntry()) && nonTerminalSymbols.contains(r.getExpansion()[0]) && nonTerminalSymbols.contains(r.getExpansion()[1])) {
+ ntRules.add(r);
+ }
+ }
+ return ntRules;
+ }
+
+ private Collection<Rule> getNTRules() {
+ Collection<Rule> ntRules = new LinkedList<Rule>();
+ for (Rule r : rules.keySet()) {
+ if (nonTerminalSymbols.contains(r.getExpansion()[0]) && nonTerminalSymbols.contains(r.getExpansion()[1])) {
+ ntRules.add(r);
+ }
+ }
+ return ntRules;
+ }
+
+ private double q(Rule rule) {
+ return rules.keySet().contains(rule) ? rules.get(rule) : 0;
+ }
+
+ public class BackPointer {
+
+ private final double probability;
+ private final int splitPoint;
+ private final Rule rule;
+ private BackPointer leftTree;
+ private BackPointer rightTree;
+
+ private BackPointer(double probability, int splitPoint, Rule rule) {
+ this.probability = probability;
+ this.splitPoint = splitPoint;
+ this.rule = rule;
+ }
+
+ public BackPointer(double probability, int splitPoint, Rule rule, BackPointer leftTree, BackPointer rightTree) {
+ this.probability = probability;
+ this.splitPoint = splitPoint;
+ this.rule = rule;
+ this.leftTree = leftTree;
+ this.rightTree = rightTree;
+ }
+
+ public double getProbability() {
+ return probability;
+ }
+
+ public int getSplitPoint() {
+ return splitPoint;
+ }
+
+ public Rule getRule() {
+ return rule;
+ }
+
+ public BackPointer getLeftTree() {
+ return leftTree;
+ }
+
+ public BackPointer getRightTree() {
+ return rightTree;
+ }
+
+ @Override
+ public String toString() {
+ return "BackPointer{" +
+ "probability=" + probability +
+ ", splitPoint=" + splitPoint +
+ ", rule=" + rule +
+ ", leftTree=" + leftTree +
+ ", rightTree=" + rightTree +
+ '}';
+ }
+ }
+
+}
diff --git a/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/Rule.java b/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/Rule.java
index f6c3b7a..85b70a6 100644
--- a/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/Rule.java
+++ b/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/Rule.java
@@ -42,8 +42,8 @@
@Override
public int compareTo(Rule o) {
- int c = entry.compareTo(o.getEntry());
- return c != 0 ? c : Arrays.toString(expansion).compareTo(Arrays.toString(o.getExpansion()));
+ int c = entry.compareTo(o.getEntry());
+ return c != 0 ? c : Arrays.toString(expansion).compareTo(Arrays.toString(o.getExpansion()));
}
@Override
@@ -53,10 +53,8 @@
Rule rule = (Rule) o;
- if (entry != null ? !entry.equals(rule.entry) : rule.entry != null) return false;
- if (!Arrays.equals(expansion, rule.expansion)) return false;
+ return !(entry != null ? !entry.equals(rule.entry) : rule.entry != null) && Arrays.equals(expansion, rule.expansion);
- return true;
}
@Override
@@ -65,4 +63,12 @@
result = 31 * result + (expansion != null ? Arrays.hashCode(expansion) : 0);
return result;
}
+
+ @Override
+ public String toString() {
+ return "{" +
+ "'" + entry + '\'' +
+ " -> " + Arrays.toString(expansion) +
+ '}';
+ }
}
diff --git a/nlp-utils/src/test/java/org/apache/opennlp/utils/cfg/ProbabilisticContextFreeGrammarTest.java b/nlp-utils/src/test/java/org/apache/opennlp/utils/cfg/ProbabilisticContextFreeGrammarTest.java
new file mode 100644
index 0000000..8a991e0
--- /dev/null
+++ b/nlp-utils/src/test/java/org/apache/opennlp/utils/cfg/ProbabilisticContextFreeGrammarTest.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.opennlp.utils.cfg;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Testcase for {@link org.apache.opennlp.utils.cfg.ProbabilisticContextFreeGrammar}
+ */
+public class ProbabilisticContextFreeGrammarTest {
+
+ private static LinkedList<String> nonTerminals;
+ private static String startSymbol;
+ private static LinkedList<String> terminals;
+ private static Map<Rule, Double> rules;
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ nonTerminals = new LinkedList<String>();
+ nonTerminals.add("S");
+ nonTerminals.add("NP");
+ nonTerminals.add("VP");
+ nonTerminals.add("DT");
+ nonTerminals.add("Vi");
+ nonTerminals.add("Vt");
+ nonTerminals.add("NN");
+ nonTerminals.add("IN");
+ nonTerminals.add("NNP");
+ nonTerminals.add("Adv");
+
+ startSymbol = "S";
+
+ terminals = new LinkedList<String>();
+ terminals.add("works");
+ terminals.add("saw");
+ terminals.add("man");
+ terminals.add("woman");
+ terminals.add("dog");
+ terminals.add("the");
+ terminals.add("with");
+ terminals.add("in");
+ terminals.add("joe");
+ terminals.add("john");
+ terminals.add("sam");
+ terminals.add("michael");
+ terminals.add("michelle");
+ terminals.add("scarlett");
+ terminals.add("and");
+ terminals.add("but");
+ terminals.add("while");
+ terminals.add("of");
+ terminals.add("for");
+ terminals.add("badly");
+ terminals.add("nicely");
+
+ rules = new HashMap<Rule, Double>();
+ rules.put(new Rule("S", "NP", "VP"), 1d);
+ rules.put(new Rule("VP", "Vi", "Adv"), 0.3);
+ rules.put(new Rule("VP", "Vt", "NP"), 0.7);
+ rules.put(new Rule("NP", "DT", "NN"), 1d);
+ rules.put(new Rule("Vi", "works"), 1d);
+ rules.put(new Rule("Vt", "saw"), 1d);
+ rules.put(new Rule("NN", "man"), 0.5);
+ rules.put(new Rule("NN", "woman"), 0.2);
+ rules.put(new Rule("NN", "dog"), 0.3);
+ rules.put(new Rule("DT", "the"), 1d);
+ rules.put(new Rule("IN", "with"), 0.2);
+ rules.put(new Rule("IN", "in"), 0.1);
+ rules.put(new Rule("IN", "for"), 0.4);
+ rules.put(new Rule("IN", "of"), 0.4);
+ rules.put(new Rule("NNP", "joe"), 0.1);
+ rules.put(new Rule("NNP", "john"), 0.1);
+ rules.put(new Rule("NNP", "sam"), 0.1);
+ rules.put(new Rule("NNP", "michael"), 0.1);
+ rules.put(new Rule("NNP", "michelle"), 0.1);
+ rules.put(new Rule("NNP", "scarlett"), 0.5);
+ rules.put(new Rule("Adv", "badly"), 0.3);
+ rules.put(new Rule("Adv", "nicely"), 0.7);
+ }
+
+ @Test
+ public void testIntermediateProbability() throws Exception {
+ ArrayList<String> sentence = new ArrayList<String>();
+ sentence.add("the");
+ sentence.add("dog");
+ sentence.add("saw");
+ sentence.add("the");
+ sentence.add("man");
+ sentence.add("with");
+ sentence.add("the");
+ sentence.add("woman");
+
+ ProbabilisticContextFreeGrammar pcfg = new ProbabilisticContextFreeGrammar(nonTerminals, terminals, rules, startSymbol);
+
+ double pi = pcfg.pi(sentence, 0, 1, pcfg.getStartSymbol()).getProbability();
+ assertTrue(pi <= 1 && pi >= 0);
+
+ pi = pcfg.pi(sentence, 2, 7, "VP").getProbability();
+ assertTrue(pi <= 1 && pi >= 0);
+ }
+
+ @Test
+ public void testFullSentenceCKY() throws Exception {
+ ProbabilisticContextFreeGrammar pcfg = new ProbabilisticContextFreeGrammar(nonTerminals, terminals, rules, startSymbol, true);
+
+ // fixed sentence one
+ List<String> sentence = new ArrayList<String>();
+ sentence.add("the");
+ sentence.add("dog");
+ sentence.add("saw");
+ sentence.add("the");
+ sentence.add("man");
+
+ ProbabilisticContextFreeGrammar.BackPointer backPointer = pcfg.cky(sentence, pcfg);
+ check(pcfg, backPointer, sentence);
+
+ // fixed sentence two
+ sentence = new ArrayList<String>();
+ sentence.add("the");
+ sentence.add("man");
+ sentence.add("works");
+ sentence.add("nicely");
+
+ backPointer = pcfg.cky(sentence, pcfg);
+ check(pcfg, backPointer, sentence);
+
+ // random sentence generated by the grammar
+ String[] expansion = pcfg.leftMostDerivation("S");
+ sentence = Arrays.asList(expansion);
+
+ backPointer = pcfg.cky(sentence, pcfg);
+ check(pcfg, backPointer, sentence);
+ }
+
+ private void check(ProbabilisticContextFreeGrammar pcfg, ProbabilisticContextFreeGrammar.BackPointer backPointer, List<String> sentence) {
+ Rule rule = backPointer.getRule();
+ assertNotNull(rule);
+ assertEquals(pcfg.getStartSymbol(), rule.getEntry());
+ int s = backPointer.getSplitPoint();
+ assertTrue(s >= 0);
+ double pi = backPointer.getProbability();
+ assertTrue(pi <= 1 && pi >= 0);
+ List<String> expandedTerminals = getTerminals(backPointer);
+ for (int i = 0; i < sentence.size(); i++) {
+ assertEquals(sentence.get(i), expandedTerminals.get(i));
+ }
+
+ }
+
+ private List<String> getTerminals(ProbabilisticContextFreeGrammar.BackPointer backPointer) {
+ if (backPointer.getLeftTree() == null && backPointer.getRightTree() == null) {
+ return Arrays.asList(backPointer.getRule().getExpansion());
+ }
+
+ ArrayList<String> list = new ArrayList<String>();
+ list.addAll(getTerminals(backPointer.getLeftTree()));
+ list.addAll(getTerminals(backPointer.getRightTree()));
+ return list;
+ }
+
+}
diff --git a/nlp-utils/src/test/java/org/apache/opennlp/utils/cfg/RuleTest.java b/nlp-utils/src/test/java/org/apache/opennlp/utils/cfg/RuleTest.java
index 40a2bd7..9c43087 100644
--- a/nlp-utils/src/test/java/org/apache/opennlp/utils/cfg/RuleTest.java
+++ b/nlp-utils/src/test/java/org/apache/opennlp/utils/cfg/RuleTest.java
@@ -28,8 +28,16 @@
* Testcase for {@link Rule}
*/
public class RuleTest {
+
@Test
public void testEquals() throws Exception {
+ Rule r1 = new Rule("NP", "NP", "PP");
+ Rule r2 = new Rule("NP", "NP", "PP");
+ assertEquals(r1, r2);
+ }
+
+ @Test
+ public void testNotEquals() throws Exception {
Rule r1 = new Rule("NP", "DT", "NN");
Rule r2 = new Rule("NP", "NP", "PP");
assertNotEquals(r1, r2);