| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package org.apache.opennlp.utils.cfg; |
| |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.Collection; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.LinkedList; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.Random; |
| import java.util.regex.Matcher; |
| import java.util.regex.Pattern; |
| |
| /** |
| * a probabilistic CFG |
| */ |
| public class ProbabilisticContextFreeGrammar { |
| |
| private final Collection<String> nonTerminalSymbols; |
| private final Collection<String> terminalSymbols; |
| private final Map<Rule, Double> rules; |
| private final String startSymbol; |
| private boolean randomExpansion; |
| |
| private static final Rule emptyRule = new Rule("E", ""); |
| |
| private static final String nonTerminalMatcher = "[\\w\\~\\*\\-\\.\\,\\'\\:\\_\\\"]"; |
| private static final String terminalMatcher = "[òàùìèé\\|\\w\\'\\.\\,\\:\\_Ù\\?È\\%\\;À\\-\\\"]"; |
| |
| private static final Pattern terminalPattern = Pattern.compile("\\(("+nonTerminalMatcher+"+)\\s("+terminalMatcher+"+)\\)"); |
| private static final Pattern nonTerminalPattern = Pattern.compile( |
| "\\(("+nonTerminalMatcher+"+)" + // source NT |
| "\\s("+nonTerminalMatcher+"+)(\\s("+nonTerminalMatcher+"+))*\\)" // expansion NTs |
| ); |
| |
| public ProbabilisticContextFreeGrammar(Collection<String> nonTerminalSymbols, Collection<String> terminalSymbols, |
| Map<Rule, Double> rules, String startSymbol, boolean randomExpansion) { |
| |
| assert nonTerminalSymbols.contains(startSymbol) : "start symbol doesn't belong to non-terminal symbols set"; |
| |
| this.nonTerminalSymbols = nonTerminalSymbols; |
| this.terminalSymbols = terminalSymbols; |
| this.rules = rules; |
| this.startSymbol = startSymbol; |
| this.randomExpansion = randomExpansion; |
| } |
| |
| public ProbabilisticContextFreeGrammar(Collection<String> nonTerminalSymbols, Collection<String> terminalSymbols, Map<Rule, Double> rules, String startSymbol) { |
| this(nonTerminalSymbols, terminalSymbols, rules, startSymbol, false); |
| } |
| |
| public Collection<String> getNonTerminalSymbols() { |
| return nonTerminalSymbols; |
| } |
| |
| public Collection<String> getTerminalSymbols() { |
| return terminalSymbols; |
| } |
| |
| public Map<Rule, Double> getRules() { |
| return rules; |
| } |
| |
| public String getStartSymbol() { |
| return startSymbol; |
| } |
| |
| |
| public String[] leftMostDerivation(String... words) { |
| ArrayList<String> expansion = new ArrayList<String>(words.length); |
| |
| assert words.length > 0 && startSymbol.equals(words[0]); |
| |
| for (String word : words) { |
| expansion.addAll(getTerminals(word)); |
| } |
| return expansion.toArray(new String[expansion.size()]); |
| |
| } |
| |
| private Collection<String> getTerminals(String word) { |
| if (terminalSymbols.contains(word)) { |
| Collection<String> c = new LinkedList<String>(); |
| c.add(word); |
| return c; |
| } else { |
| assert nonTerminalSymbols.contains(word) : "word " + word + " is not contained in non terminals"; |
| String[] expansions = getExpansionForSymbol(word); |
| Collection<String> c = new LinkedList<String>(); |
| for (String e : expansions) { |
| c.addAll(getTerminals(e)); |
| } |
| return c; |
| } |
| } |
| |
| private String[] getExpansionForSymbol(String currentSymbol) { |
| Rule r = getRuleForSymbol(currentSymbol); |
| return r.getExpansion(); |
| } |
| |
| private Rule getRuleForSymbol(String word) { |
| ArrayList<Rule> possibleRules = new ArrayList<Rule>(); |
| for (Rule r : rules.keySet()) { |
| if (word.equals(r.getEntry())) { |
| if (!randomExpansion) { |
| return r; |
| } |
| possibleRules.add(r); |
| } |
| } |
| if (possibleRules.size() > 0) { |
| return possibleRules.get(new Random().nextInt(possibleRules.size())); |
| } else { |
| throw new RuntimeException("could not find a rule for expanding symbol " + word); |
| } |
| } |
| |
| public ParseTree pi(List<String> sentence, int i, int j, String x) { |
| ParseTree parseTree = new ParseTree(0, 0, null); |
| if (i == j) { |
| Rule rule = new Rule(x, sentence.get(i)); |
| double q = q(rule); |
| parseTree = new ParseTree(q, i, rule); |
| } else { |
| double max = 0; |
| for (Rule rule : getNTRules()) { |
| for (int s = i; s < j; s++) { |
| double q = q(rule); |
| ParseTree left = pi(sentence, i, s, rule.getExpansion()[0]); |
| ParseTree right = pi(sentence, s + 1, j, rule.getExpansion()[1]); |
| double cp = q * left.getProbability() * right.getProbability(); |
| if (cp > max) { |
| max = cp; |
| parseTree = new ParseTree(max, s, rule, left, right); |
| } |
| } |
| } |
| } |
| return parseTree; |
| } |
| |
| public ParseTree cky(List<String> sentence) { |
| ParseTree parseTree = null; |
| |
| int n = sentence.size(); |
| for (int l = 1; l < n; l++) { |
| for (int i = 0; i < n - l; i++) { |
| int j = i + l; |
| double max = 0; |
| for (String x : getNonTerminalSymbols()) { |
| for (Rule r : getRulesForNonTerminal(x)) { |
| for (int s = i; s < j - 1; s++) { |
| double q = q(r); |
| ParseTree left = pi(sentence, i, s, r.getExpansion()[0]); |
| ParseTree right = pi(sentence, s + 1, j, r.getExpansion()[1]); |
| double cp = q * left.getProbability() * right.getProbability(); |
| if (cp > max) { |
| max = cp; |
| parseTree = new ParseTree(max, s, r, left, right); |
| } |
| } |
| } |
| } |
| } |
| } |
| return parseTree; |
| } |
| |
| private Collection<Rule> getRulesForNonTerminal(String x) { |
| LinkedList<Rule> ntRules = new LinkedList<Rule>(); |
| for (Rule r : rules.keySet()) { |
| String[] expansion = r.getExpansion(); |
| if (expansion.length == 2 && x.equals(r.getEntry()) && nonTerminalSymbols.contains(expansion[0]) && nonTerminalSymbols.contains(expansion[1])) { |
| ntRules.add(r); |
| } |
| } |
| return ntRules; |
| } |
| |
| private Collection<Rule> getNTRules() { |
| Collection<Rule> ntRules = new LinkedList<Rule>(); |
| for (Rule r : rules.keySet()) { |
| String[] expansion = r.getExpansion(); |
| if (expansion.length == 2 && nonTerminalSymbols.contains(expansion[0]) && nonTerminalSymbols.contains(expansion[1])) { |
| ntRules.add(r); |
| } |
| } |
| return ntRules; |
| } |
| |
| private double q(Rule rule) { |
| return rules.keySet().contains(rule) ? rules.get(rule) : 0; |
| } |
| |
| public class ParseTree { |
| |
| private final double probability; |
| private final int splitPoint; |
| private final Rule rule; |
| private ParseTree leftTree; |
| private ParseTree rightTree; |
| |
| private ParseTree(double probability, int splitPoint, Rule rule) { |
| this.probability = probability; |
| this.splitPoint = splitPoint; |
| this.rule = rule; |
| } |
| |
| public ParseTree(double probability, int splitPoint, Rule rule, ParseTree leftTree, ParseTree rightTree) { |
| this.probability = probability; |
| this.splitPoint = splitPoint; |
| this.rule = rule; |
| this.leftTree = leftTree; |
| this.rightTree = rightTree; |
| } |
| |
| public double getProbability() { |
| return probability; |
| } |
| |
| public int getSplitPoint() { |
| return splitPoint; |
| } |
| |
| public Rule getRule() { |
| return rule; |
| } |
| |
| public ParseTree getLeftTree() { |
| return leftTree; |
| } |
| |
| public ParseTree getRightTree() { |
| return rightTree; |
| } |
| |
| @Override |
| public String toString() { |
| if (getRule() != emptyRule) { |
| return "(" + |
| rule.getEntry() + " " + |
| (leftTree != null && rightTree != null ? |
| leftTree.toString() + " " + rightTree.toString() : |
| rule.getExpansion()[0] |
| ) + |
| ')'; |
| } else { |
| return ""; |
| } |
| } |
| |
| } |
| |
| public static Map<Rule, Double> parseRules(String... parseTreeString) { |
| Map<Rule, Double> rules = new HashMap<>(); |
| parseRules(rules, false, parseTreeString); |
| return rules; |
| } |
| |
| public static void parseRules(Map<Rule, Double> rules, boolean trim, String... parseStrings) { |
| parseGrammar(rules, "S", trim, parseStrings); |
| } |
| |
| public static ProbabilisticContextFreeGrammar parseGrammar(boolean trim, String... parseTreeStrings) { |
| return parseGrammar(new HashMap<Rule, Double>(), "S", trim, parseTreeStrings); |
| } |
| |
| public static ProbabilisticContextFreeGrammar parseGrammar(String... parseTreeStrings) { |
| return parseGrammar(new HashMap<Rule, Double>(), "S", true, parseTreeStrings); |
| } |
| |
| public static ProbabilisticContextFreeGrammar parseGrammar(Map<Rule, Double> rulesMap, String startSymbol, boolean trim, String... parseStrings) { |
| |
| Map<Rule, Double> rules = new HashMap<>(); |
| |
| Collection<String> nonTerminals = new HashSet<>(); |
| Collection<String> terminals = new HashSet<>(); |
| |
| for (String parseTreeString : parseStrings) { |
| |
| if (trim) { |
| parseTreeString = parseTreeString.replaceAll("\n", "").replaceAll("\t", "").replaceAll("\\s+", " "); |
| } |
| |
| String toConsume = String.valueOf(parseTreeString); |
| |
| Matcher m = terminalPattern.matcher(parseTreeString); |
| while (m.find()) { |
| String nt = m.group(1); |
| String t = m.group(2); |
| Rule key = new Rule(nt, t); |
| if (!rules.containsKey(key)) { |
| rules.put(key, 1d); |
| terminals.add(t); |
| // System.err.println(key); |
| } |
| toConsume = toConsume.replace(m.group(), nt); |
| } |
| |
| while (toConsume.contains(" ") && !toConsume.trim().equals("( " + startSymbol + " )")) { |
| Matcher m2 = nonTerminalPattern.matcher(toConsume); |
| while (m2.find()) { |
| String nt = m2.group(1); |
| String t1 = m2.group(2); |
| String t2 = m2.group(3); |
| |
| Rule key; |
| if (t2 != null) { |
| String[] t2s = t2.trim().split(" "); |
| String[] nts = new String[t2s.length + 1]; |
| nts[0] = t1; |
| System.arraycopy(t2s, 0, nts, 1, t2s.length); |
| key = new Rule(nt, nts); |
| nonTerminals.addAll(Arrays.asList(nts)); |
| } else { |
| key = new Rule(nt, t1); |
| nonTerminals.add(t1); |
| } |
| nonTerminals.add(key.getEntry()); |
| |
| if (!rules.containsKey(key)) { |
| rules.put(key, 1d); |
| // startSymbol = key.getEntry(); |
| // System.err.println(key); |
| } |
| toConsume = toConsume.replace(m2.group(), nt); |
| } |
| } |
| } |
| |
| // TODO : check/adjust rules to make them respect CNF |
| // TODO : adjust probabilities based on term frequencies |
| for (Map.Entry<Rule, Double> entry : rules.entrySet()) { |
| normalize(entry.getKey(), nonTerminals, terminals, rulesMap); |
| } |
| |
| return new ProbabilisticContextFreeGrammar(nonTerminals, terminals, rulesMap, startSymbol, true); |
| } |
| |
| private static void normalize(Rule rule, Collection<String> nonTerminals, Collection<String> terminals, Map<Rule, Double> rulesMap) { |
| String[] expansion = rule.getExpansion(); |
| if (expansion.length == 1) { |
| if (!terminals.contains(expansion[0])) { |
| if (nonTerminals.contains(expansion[0])) { |
| // nt1 -> nt2 should be expanded in nt1 -> nt2,E |
| rulesMap.put(new Rule(rule.getEntry(), expansion[0], emptyRule.getEntry()), 1d); |
| if (rulesMap.containsKey(emptyRule)) { |
| rulesMap.put(emptyRule, 1d); |
| } |
| } else { |
| throw new RuntimeException("rule "+rule+" expands to neither a terminal or non terminal"); |
| } |
| } else { |
| rulesMap.put(rule, 1d); |
| } |
| } else if (expansion.length > 2){ |
| // nt1 -> nt2,nt3,...,ntn should be collapsed to a hierarchy of ntX -> ntY,ntZ rules |
| String nt2 = expansion[0]; |
| int seed = nonTerminals.size(); |
| String generatedNT = "GEN~" + seed; |
| nonTerminals.add(generatedNT); |
| Rule newRule = new Rule(rule.getEntry(), nt2, generatedNT); |
| rulesMap.put(newRule, 1d); |
| Rule chainedRule = new Rule(generatedNT, Arrays.copyOfRange(expansion, 1, expansion.length - 1)); |
| rulesMap.put(chainedRule, 1d); |
| normalize(chainedRule, nonTerminals, terminals, rulesMap); |
| } else { |
| rulesMap.put(rule, 1d); |
| } |
| } |
| } |