OPENNLP-723 - pcfg support in sandbox (nlp-utils)
diff --git a/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/ContextFreeGrammar.java b/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/ContextFreeGrammar.java
index 9687d36..f3ae0d0 100644
--- a/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/ContextFreeGrammar.java
+++ b/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/ContextFreeGrammar.java
@@ -35,6 +35,22 @@
   private final String startSymbol;
   private final boolean randomExpansion;
 
+  /** Returns the non-terminal symbols field as stored, without copying it. */
+  public Collection<String> getNonTerminalSymbols() {
+    return this.nonTerminalSymbols;
+  }
+  /** Returns the terminal symbols field as stored, without copying it. */
+  public Collection<String> getTerminalSymbols() {
+    return this.terminalSymbols;
+  }
+  /** Returns the rules field as stored, without copying it. */
+  public Collection<Rule> getRules() {
+    return this.rules;
+  }
+  /** Returns the start symbol of this grammar. */
+  public String getStartSymbol() {
+    return this.startSymbol;
+  }
   public ContextFreeGrammar(Collection<String> nonTerminalSymbols, Collection<String> terminalSymbols, Collection<Rule> rules, String startSymbol, boolean randomExpansion) {
     assert nonTerminalSymbols.contains(startSymbol) : "start symbol doesn't belong to non-terminal symbols set";
 
@@ -62,7 +78,6 @@
   }
 
   private Collection<String> getTerminals(String word) {
-
     if (terminalSymbols.contains(word)) {
       Collection<String> c = new LinkedList<String>();
       c.add(word);
@@ -87,7 +102,7 @@
     ArrayList<Rule> possibleRules = new ArrayList<Rule>();
     for (Rule r : rules) {
       if (word.equals(r.getEntry())) {
-        if (randomExpansion) {
+        if (!randomExpansion) {
           return r;
         }
         possibleRules.add(r);
diff --git a/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/ProbabilisticContextFreeGrammar.java b/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/ProbabilisticContextFreeGrammar.java
new file mode 100644
index 0000000..12f58c9
--- /dev/null
+++ b/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/ProbabilisticContextFreeGrammar.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.opennlp.utils.cfg;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * A probabilistic context-free grammar (PCFG): a context-free grammar whose rules each carry a probability.
+ */
+public class ProbabilisticContextFreeGrammar {
+
+  private final Collection<String> nonTerminalSymbols;
+  private final Collection<String> terminalSymbols;
+  // probability attached to each rule; q(rule) falls back to 0 for rules missing from this map
+  private final Map<Rule, Double> rules;
+  private final String startSymbol;
+  private final boolean randomExpansion;
+  /** Stores the arguments as passed; with assertions enabled, startSymbol must be among nonTerminalSymbols. */
+  public ProbabilisticContextFreeGrammar(Collection<String> nonTerminalSymbols, Collection<String> terminalSymbols,
+                                         Map<Rule, Double> rules, String startSymbol, boolean randomExpansion) {
+    assert nonTerminalSymbols.contains(startSymbol) : "start symbol doesn't belong to non-terminal symbols set";
+
+    this.startSymbol = startSymbol;
+    this.randomExpansion = randomExpansion;
+    this.rules = rules;
+    this.terminalSymbols = terminalSymbols;
+    this.nonTerminalSymbols = nonTerminalSymbols;
+  }
+
+  /** Same as the five-argument constructor, with randomExpansion set to false. */
+  public ProbabilisticContextFreeGrammar(Collection<String> nonTerminalSymbols, Collection<String> terminalSymbols, Map<Rule, Double> rules, String startSymbol) {
+    this(nonTerminalSymbols, terminalSymbols, rules, startSymbol, false);
+  }
+  /** Returns the non-terminal symbols collection that was passed to the constructor (not a copy). */
+  public Collection<String> getNonTerminalSymbols() {
+    return this.nonTerminalSymbols;
+  }
+  /** Returns the terminal symbols collection that was passed to the constructor (not a copy). */
+  public Collection<String> getTerminalSymbols() {
+    return this.terminalSymbols;
+  }
+  /** Returns the rule-to-probability map that was passed to the constructor (not a copy). */
+  public Map<Rule, Double> getRules() {
+    return this.rules;
+  }
+  /** Returns the start symbol of this grammar. */
+  public String getStartSymbol() {
+    return this.startSymbol;
+  }
+
+  /**
+   * Expands the given symbols, left to right, into the terminals they derive; asserts that the first is the start symbol.
+   */
+  public String[] leftMostDerivation(String... words) {
+    assert words.length > 0 && startSymbol.equals(words[0]);
+    List<String> terminals = new ArrayList<String>(words.length);
+    for (String symbol : words) {
+      terminals.addAll(getTerminals(symbol));
+    }
+    return terminals.toArray(new String[terminals.size()]);
+  }
+
+  /** Recursively expands a symbol into the terminals it derives; a terminal symbol derives only itself. */
+  private Collection<String> getTerminals(String word) {
+    Collection<String> derived = new LinkedList<String>();
+    if (terminalSymbols.contains(word)) {
+      derived.add(word);
+      return derived;
+    }
+
+    assert nonTerminalSymbols.contains(word) : "word " + word + " is not contained in non terminals";
+    // getExpansionForSymbol throws when no rule rewrites the symbol
+    for (String symbol : getExpansionForSymbol(word)) {
+      derived.addAll(getTerminals(symbol));
+    }
+    return derived;
+  }
+
+  /** Returns the expansion of the rule that getRuleForSymbol selects for the symbol. */
+  private String[] getExpansionForSymbol(String currentSymbol) {
+    return getRuleForSymbol(currentSymbol).getExpansion();
+  }
+
+  /** Returns a rule whose entry equals word: the first one iterated, or a random match if randomExpansion is set. */
+  private Rule getRuleForSymbol(String word) {
+    List<Rule> candidates = new ArrayList<Rule>();
+    for (Rule candidate : rules.keySet()) {
+      if (word.equals(candidate.getEntry())) {
+        if (!randomExpansion) {
+          return candidate;
+        }
+        candidates.add(candidate);
+      }
+    }
+    if (candidates.isEmpty()) {
+      throw new RuntimeException("could not find a rule for expanding symbol " + word);
+    }
+    return candidates.get(new Random().nextInt(candidates.size()));
+  }
+
+  /**
+   * Most probable derivation of words i..j (inclusive) rooted at x; a one-word span scores the rule x -> word.
+   * Longer spans try only binary rules whose entry is x and yield probability 0 with a null rule if none fits.
+   */
+  public BackPointer pi(List<String> sentence, int i, int j, String x) {
+    if (i == j) {
+      Rule lexical = new Rule(x, sentence.get(i));
+      return new BackPointer(q(lexical), i, lexical);
+    }
+    BackPointer best = new BackPointer(0, 0, null);
+    for (Rule rule : getRulesForNonTerminal(x)) {
+      double q = q(rule);
+      for (int s = i; s < j; s++) {
+        BackPointer left = pi(sentence, i, s, rule.getExpansion()[0]);
+        BackPointer right = pi(sentence, s + 1, j, rule.getExpansion()[1]);
+        double cp = q * left.getProbability() * right.getProbability();
+        if (cp > best.getProbability()) {
+          best = new BackPointer(cp, s, rule, left, right);
+        }
+      }
+    }
+    return best;
+  }
+
+  /**
+   * Finds the most probable derivation covering the whole sentence. Every non-terminal of pcfg is tried
+   * as root and every split s is examined (left part = words 0..s, right part = words s+1..n-1); rules
+   * and their probabilities are read from this grammar.
+   * @return the best back pointer, or null for fewer than two words or when nothing covers the sentence
+   */
+  public BackPointer cky(List<String> sentence, ProbabilisticContextFreeGrammar pcfg) {
+    BackPointer best = null;
+    double max = 0;
+    int last = sentence.size() - 1;
+    for (String x : pcfg.getNonTerminalSymbols()) {
+      for (Rule r : getRulesForNonTerminal(x)) {
+        double q = q(r);
+        for (int s = 0; s < last; s++) {
+          BackPointer left = pi(sentence, 0, s, r.getExpansion()[0]);
+          BackPointer right = pi(sentence, s + 1, last, r.getExpansion()[1]);
+          double cp = q * left.getProbability() * right.getProbability();
+          if (cp > max) {
+            max = cp;
+            best = new BackPointer(max, s, r, left, right);
+          }
+        }
+      }
+    }
+    return best;
+  }
+
+  /** Collects the rules rewriting x into exactly two symbols that are both non-terminals. */
+  private Collection<Rule> getRulesForNonTerminal(String x) {
+    LinkedList<Rule> ntRules = new LinkedList<Rule>();
+    for (Rule r : rules.keySet()) {
+      if (x.equals(r.getEntry()) && r.getExpansion().length == 2 && nonTerminalSymbols.contains(r.getExpansion()[0]) && nonTerminalSymbols.contains(r.getExpansion()[1])) {
+        ntRules.add(r);
+      }
+    }
+    return ntRules;
+  }
+  /** Collects every rule whose expansion is exactly two symbols that are both non-terminals. */
+  private Collection<Rule> getNTRules() {
+    Collection<Rule> ntRules = new LinkedList<Rule>();
+    for (Rule r : rules.keySet()) {
+      if (r.getExpansion().length == 2 && nonTerminalSymbols.contains(r.getExpansion()[0]) && nonTerminalSymbols.contains(r.getExpansion()[1])) {
+        ntRules.add(r);
+      }
+    }
+    return ntRules;
+  }
+  /** Returns the probability mapped to the rule, or 0 when the rule is not a key of the rules map. */
+  private double q(Rule rule) {
+    return rules.containsKey(rule) ? rules.get(rule) : 0;
+  }
+  /** Node of a parse tree: the rule applied to a span, the split point used and the resulting probability. */
+  public class BackPointer {
+
+    private final double probability;
+    private final int splitPoint;
+    private final Rule rule;
+    private final BackPointer leftTree;
+    private final BackPointer rightTree;
+
+    /** Creates a pointer without sub-trees; both trees are null. */
+    private BackPointer(double probability, int splitPoint, Rule rule) {
+      this(probability, splitPoint, rule, null, null);
+    }
+
+    /** Creates a pointer for a rule applied with the given left and right sub-trees. */
+    public BackPointer(double probability, int splitPoint, Rule rule, BackPointer leftTree, BackPointer rightTree) {
+      this.probability = probability;
+      this.splitPoint = splitPoint;
+      this.rule = rule;
+      this.leftTree = leftTree;
+      this.rightTree = rightTree;
+    }
+
+    public double getProbability() {
+      return probability;
+    }
+
+    /** @return the stored split index: pi and cky store the last word index of the left part, or the word index for one-word spans */
+    public int getSplitPoint() {
+      return splitPoint;
+    }
+
+    /** @return the rule of this node; null in the zero-probability pointer pi starts from */
+    public Rule getRule() {
+      return rule;
+    }
+
+    /** @return the left sub-tree, or null when the pointer was created without sub-trees */
+    public BackPointer getLeftTree() {
+      return leftTree;
+    }
+
+    /** @return the right sub-tree, or null when the pointer was created without sub-trees */
+    public BackPointer getRightTree() {
+      return rightTree;
+    }
+
+    @Override
+    public String toString() {
+      return "BackPointer{" + "probability=" + probability + ", splitPoint=" + splitPoint + ", rule=" + rule
+              + ", leftTree=" + leftTree + ", rightTree=" + rightTree + '}';
+    }
+  }
+
+}
diff --git a/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/Rule.java b/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/Rule.java
index f6c3b7a..85b70a6 100644
--- a/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/Rule.java
+++ b/nlp-utils/src/main/java/org/apache/opennlp/utils/cfg/Rule.java
@@ -42,8 +42,8 @@
 
   @Override
   public int compareTo(Rule o) {
-      int c = entry.compareTo(o.getEntry());
-      return c != 0 ? c : Arrays.toString(expansion).compareTo(Arrays.toString(o.getExpansion()));
+    int c = entry.compareTo(o.getEntry());
+    return c != 0 ? c : Arrays.toString(expansion).compareTo(Arrays.toString(o.getExpansion()));
   }
 
   @Override
@@ -53,10 +53,8 @@
 
     Rule rule = (Rule) o;
 
-    if (entry != null ? !entry.equals(rule.entry) : rule.entry != null) return false;
-    if (!Arrays.equals(expansion, rule.expansion)) return false;
+    return !(entry != null ? !entry.equals(rule.entry) : rule.entry != null) && Arrays.equals(expansion, rule.expansion);
 
-    return true;
   }
 
   @Override
@@ -65,4 +63,12 @@
     result = 31 * result + (expansion != null ? Arrays.hashCode(expansion) : 0);
     return result;
   }
+
+  /**
+   * Renders the rule as {'entry' -> [symbol, symbol, ...]}, the expansion being formatted by Arrays.toString.
+   */
+  @Override
+  public String toString() {
+    return "{'" + entry + "' -> " + Arrays.toString(expansion) + '}';
+  }
 }
diff --git a/nlp-utils/src/test/java/org/apache/opennlp/utils/cfg/ProbabilisticContextFreeGrammarTest.java b/nlp-utils/src/test/java/org/apache/opennlp/utils/cfg/ProbabilisticContextFreeGrammarTest.java
new file mode 100644
index 0000000..8a991e0
--- /dev/null
+++ b/nlp-utils/src/test/java/org/apache/opennlp/utils/cfg/ProbabilisticContextFreeGrammarTest.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.opennlp.utils.cfg;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Testcase for {@link org.apache.opennlp.utils.cfg.ProbabilisticContextFreeGrammar}
+ */
+public class ProbabilisticContextFreeGrammarTest {
+
+  private static LinkedList<String> nonTerminals;
+  private static String startSymbol;
+  private static LinkedList<String> terminals;
+  private static Map<Rule, Double> rules;
+
+  @BeforeClass
+  public static void setUp() throws Exception {
+    nonTerminals = new LinkedList<String>();
+    nonTerminals.add("S");
+    nonTerminals.add("NP");
+    nonTerminals.add("VP");
+    nonTerminals.add("DT");
+    nonTerminals.add("Vi");
+    nonTerminals.add("Vt");
+    nonTerminals.add("NN");
+    nonTerminals.add("IN");
+    nonTerminals.add("NNP");
+    nonTerminals.add("Adv");
+    // start symbol of the grammar
+    startSymbol = "S";
+    // terminal symbols (the vocabulary)
+    terminals = new LinkedList<String>();
+    terminals.add("works");
+    terminals.add("saw");
+    terminals.add("man");
+    terminals.add("woman");
+    terminals.add("dog");
+    terminals.add("the");
+    terminals.add("with");
+    terminals.add("in");
+    terminals.add("joe");
+    terminals.add("john");
+    terminals.add("sam");
+    terminals.add("michael");
+    terminals.add("michelle");
+    terminals.add("scarlett");
+    terminals.add("and");
+    terminals.add("but");
+    terminals.add("while");
+    terminals.add("of");
+    terminals.add("for");
+    terminals.add("badly");
+    terminals.add("nicely");
+    // rules with the probability assigned to each
+    rules = new HashMap<Rule, Double>();
+    rules.put(new Rule("S", "NP", "VP"), 1d);
+    rules.put(new Rule("VP", "Vi", "Adv"), 0.3);
+    rules.put(new Rule("VP", "Vt", "NP"), 0.7);
+    rules.put(new Rule("NP", "DT", "NN"), 1d);
+    rules.put(new Rule("Vi", "works"), 1d);
+    rules.put(new Rule("Vt", "saw"), 1d);
+    rules.put(new Rule("NN", "man"), 0.5);
+    rules.put(new Rule("NN", "woman"), 0.2);
+    rules.put(new Rule("NN", "dog"), 0.3);
+    rules.put(new Rule("DT", "the"), 1d);
+    rules.put(new Rule("IN", "with"), 0.2);
+    rules.put(new Rule("IN", "in"), 0.1);
+    rules.put(new Rule("IN", "for"), 0.4);
+    rules.put(new Rule("IN", "of"), 0.4); // NOTE(review): the four IN probabilities add up to 1.1, not 1 - confirm intended
+    rules.put(new Rule("NNP", "joe"), 0.1);
+    rules.put(new Rule("NNP", "john"), 0.1);
+    rules.put(new Rule("NNP", "sam"), 0.1);
+    rules.put(new Rule("NNP", "michael"), 0.1);
+    rules.put(new Rule("NNP", "michelle"), 0.1);
+    rules.put(new Rule("NNP", "scarlett"), 0.5);
+    rules.put(new Rule("Adv", "badly"), 0.3);
+    rules.put(new Rule("Adv", "nicely"), 0.7);
+  }
+
+  @Test
+  public void testIntermediateProbability() throws Exception {
+    ArrayList<String> sentence = new ArrayList<String>();
+    sentence.add("the");
+    sentence.add("dog");
+    sentence.add("saw");
+    sentence.add("the");
+    sentence.add("man");
+    sentence.add("with");
+    sentence.add("the");
+    sentence.add("woman");
+
+    ProbabilisticContextFreeGrammar pcfg = new ProbabilisticContextFreeGrammar(nonTerminals, terminals, rules, startSymbol);
+    // words 0..1 ("the dog") rooted at the start symbol; only the [0, 1] range is checked
+    double pi = pcfg.pi(sentence, 0, 1, pcfg.getStartSymbol()).getProbability();
+    assertTrue(pi <= 1 && pi >= 0);
+    // words 2..7 ("saw the man with the woman") rooted at VP; again only the range is checked
+    pi = pcfg.pi(sentence, 2, 7, "VP").getProbability();
+    assertTrue(pi <= 1 && pi >= 0);
+  }
+
+  @Test
+  public void testFullSentenceCKY() throws Exception {
+    ProbabilisticContextFreeGrammar pcfg = new ProbabilisticContextFreeGrammar(nonTerminals, terminals, rules, startSymbol, true);
+    // randomExpansion is enabled (last argument) so that leftMostDerivation below draws among the matching rules
+    // fixed sentence one
+    List<String> sentence = new ArrayList<String>();
+    sentence.add("the");
+    sentence.add("dog");
+    sentence.add("saw");
+    sentence.add("the");
+    sentence.add("man");
+
+    ProbabilisticContextFreeGrammar.BackPointer backPointer = pcfg.cky(sentence, pcfg);
+    check(pcfg, backPointer, sentence);
+
+    // fixed sentence two
+    sentence = new ArrayList<String>();
+    sentence.add("the");
+    sentence.add("man");
+    sentence.add("works");
+    sentence.add("nicely");
+
+    backPointer = pcfg.cky(sentence, pcfg);
+    check(pcfg, backPointer, sentence);
+
+    // random sentence generated by the grammar
+    String[] expansion = pcfg.leftMostDerivation("S");
+    sentence = Arrays.asList(expansion);
+
+    backPointer = pcfg.cky(sentence, pcfg);
+    check(pcfg, backPointer, sentence);
+  }
+
+  /** Asserts that the parse is rooted at the start symbol, has a probability in [0, 1] and yields the sentence. */
+  private void check(ProbabilisticContextFreeGrammar pcfg, ProbabilisticContextFreeGrammar.BackPointer backPointer, List<String> sentence) {
+    assertNotNull(backPointer);
+    Rule rule = backPointer.getRule();
+    assertNotNull(rule);
+    assertEquals(pcfg.getStartSymbol(), rule.getEntry());
+    int s = backPointer.getSplitPoint();
+    assertTrue(s >= 0);
+    double pi = backPointer.getProbability();
+    assertTrue(pi <= 1 && pi >= 0);
+    // whole-list equality also fails when the parse yields more or fewer words than the sentence
+    List<String> expandedTerminals = getTerminals(backPointer);
+    assertEquals(sentence, expandedTerminals);
+  }
+
+  /** Reads the words off the leaves of a parse tree, left to right. */
+  private List<String> getTerminals(ProbabilisticContextFreeGrammar.BackPointer backPointer) {
+    ProbabilisticContextFreeGrammar.BackPointer left = backPointer.getLeftTree();
+    if (left == null && backPointer.getRightTree() == null) {
+      return Arrays.asList(backPointer.getRule().getExpansion());
+    }
+    List<String> words = new ArrayList<String>(getTerminals(left));
+    words.addAll(getTerminals(backPointer.getRightTree()));
+    return words;
+  }
+
+}
diff --git a/nlp-utils/src/test/java/org/apache/opennlp/utils/cfg/RuleTest.java b/nlp-utils/src/test/java/org/apache/opennlp/utils/cfg/RuleTest.java
index 40a2bd7..9c43087 100644
--- a/nlp-utils/src/test/java/org/apache/opennlp/utils/cfg/RuleTest.java
+++ b/nlp-utils/src/test/java/org/apache/opennlp/utils/cfg/RuleTest.java
@@ -28,8 +28,16 @@
  * Testcase for {@link Rule}
  */
 public class RuleTest {
+
   @Test
   public void testEquals() throws Exception {
+    Rule r1 = new Rule("NP", "NP", "PP");
+    Rule r2 = new Rule("NP", "NP", "PP");
+    assertEquals(r1, r2);
+  }
+
+  @Test
+  public void testNotEquals() throws Exception {
     Rule r1 = new Rule("NP", "DT", "NN");
     Rule r2 = new Rule("NP", "NP", "PP");
     assertNotEquals(r1, r2);