blob: 0c5542e57210fc30c5aa036c1c05a77bc9dd4133 [file] [log] [blame]
package org.apache.rya.rdftriplestore.inference;
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.Set;
import java.util.Stack;
import org.apache.rya.accumulo.AccumuloRdfConfiguration;
import org.apache.rya.rdftriplestore.utils.FixedStatementPattern;
import org.eclipse.rdf4j.model.IRI;
import org.eclipse.rdf4j.model.Resource;
import org.eclipse.rdf4j.model.Statement;
import org.eclipse.rdf4j.model.ValueFactory;
import org.eclipse.rdf4j.model.impl.SimpleValueFactory;
import org.eclipse.rdf4j.model.vocabulary.RDF;
import org.eclipse.rdf4j.model.vocabulary.RDFS;
import org.eclipse.rdf4j.query.algebra.Join;
import org.eclipse.rdf4j.query.algebra.Projection;
import org.eclipse.rdf4j.query.algebra.ProjectionElem;
import org.eclipse.rdf4j.query.algebra.ProjectionElemList;
import org.eclipse.rdf4j.query.algebra.StatementPattern;
import org.eclipse.rdf4j.query.algebra.TupleExpr;
import org.eclipse.rdf4j.query.algebra.Union;
import org.eclipse.rdf4j.query.algebra.Var;
import org.junit.Assert;
import org.junit.Test;
/**
 * Unit test for {@link DomainRangeVisitor}.
 *
 * <p>Builds a query containing a single {@code rdf:type} statement pattern, runs the visitor
 * against a mocked {@link InferenceEngine}, and inspects the rewritten query tree.
 */
public class DomainRangeVisitorTest {
    private static final AccumuloRdfConfiguration conf = new AccumuloRdfConfiguration();
    private static final ValueFactory VF = SimpleValueFactory.getInstance();
    private static final IRI person = VF.createIRI("lubm:Person");
    private static final IRI advisor = VF.createIRI("lubm:advisor");
    private static final IRI takesCourse = VF.createIRI("lubm:takesCourse");

    /**
     * Verifies the rewrite of {@code ?s rdf:type lubm:Person} when the mocked inference engine
     * reports two properties with domain {@code lubm:Person} (advisor, takesCourse) and one
     * property with range {@code lubm:Person} (advisor).
     *
     * <p>The rewritten tree is expected to be nested {@link Union}s whose leaves are, exactly once
     * each: the original statement pattern, a {@link Join} for domain inference, and a
     * {@link Join} for range inference. Any other node type fails the test.
     *
     * @throws Exception if the visitor throws while rewriting the query
     */
    @Test
    public void testRewriteTypePattern() throws Exception {
        final InferenceEngine inferenceEngine = mock(InferenceEngine.class);
        final Set<IRI> domainPredicates = new HashSet<>();
        final Set<IRI> rangePredicates = new HashSet<>();
        domainPredicates.add(advisor);
        domainPredicates.add(takesCourse);
        rangePredicates.add(advisor);
        when(inferenceEngine.getPropertiesWithDomain(person)).thenReturn(domainPredicates);
        when(inferenceEngine.getPropertiesWithRange(person)).thenReturn(rangePredicates);

        final Var subjVar = new Var("s");
        final StatementPattern originalSP = new StatementPattern(subjVar, new Var("p", RDF.TYPE), new Var("o", person));
        final Projection query = new Projection(originalSP, new ProjectionElemList(new ProjectionElem("s", "subject")));
        query.visit(new DomainRangeVisitor(conf, inferenceEngine));

        // Resulting tree should consist of Unions of:
        // 1. The original StatementPattern
        // 2. A join checking for domain inference
        // 3. A join checking for range inference
        boolean containsOriginal = false;
        boolean containsDomain = false;
        boolean containsRange = false;

        // Depth-first walk over the rewritten tree; ArrayDeque used as a LIFO stack.
        final Deque<TupleExpr> nodes = new ArrayDeque<>();
        nodes.push(query.getArg());
        while (!nodes.isEmpty()) {
            final TupleExpr currentNode = nodes.pop();
            if (currentNode instanceof Union) {
                nodes.push(((Union) currentNode).getLeftArg());
                nodes.push(((Union) currentNode).getRightArg());
            }
            else if (currentNode instanceof StatementPattern) {
                Assert.assertFalse("Original StatementPattern should appear only once", containsOriginal);
                Assert.assertEquals("Bare StatementPattern should be the original pattern",
                        originalSP, currentNode);
                containsOriginal = true;
            }
            else if (currentNode instanceof Join) {
                final TupleExpr left = ((Join) currentNode).getLeftArg();
                final TupleExpr right = ((Join) currentNode).getRightArg();
                Assert.assertTrue("Left-hand side should enumerate domain/range predicates",
                        left instanceof FixedStatementPattern);
                Assert.assertTrue("Right-hand side should be a non-expandable SP matching triples satisfying domain/range",
                        right instanceof DoNotExpandSP);
                final FixedStatementPattern fsp = (FixedStatementPattern) left;
                final StatementPattern sp = (StatementPattern) right;

                // fsp should be <predicate var, domain/range, original type var>
                final boolean isDomain = RDFS.DOMAIN.equals(fsp.getPredicateVar().getValue());
                final boolean isRange = RDFS.RANGE.equals(fsp.getPredicateVar().getValue());
                Assert.assertTrue("FixedStatementPattern predicate should be rdfs:domain or rdfs:range",
                        isDomain || isRange);
                Assert.assertEquals("FixedStatementPattern object should be the original type var",
                        originalSP.getObjectVar(), fsp.getObjectVar());
                // sp should have same predicate var
                Assert.assertEquals("Joined SP predicate var should be the FixedStatementPattern subject var",
                        fsp.getSubjectVar(), sp.getPredicateVar());

                // collect predicates that are enumerated in the FixedStatementPattern
                final Set<Resource> queryPredicates = new HashSet<>();
                for (final Statement statement : fsp.statements) {
                    queryPredicates.add(statement.getSubject());
                }

                if (isDomain) {
                    Assert.assertFalse("Domain join should appear only once", containsDomain);
                    // sp should be <original subject var, predicate var, unbound object var> for domain
                    Assert.assertEquals("Domain SP subject should be the original subject var",
                            originalSP.getSubjectVar(), sp.getSubjectVar());
                    Assert.assertFalse("Domain SP object should be unbound", sp.getObjectVar().hasValue());
                    // verify predicates
                    Assert.assertEquals("Unexpected number of domain statements", 2, fsp.statements.size());
                    Assert.assertEquals("Unexpected set of domain predicates", domainPredicates, queryPredicates);
                    Assert.assertTrue("Domain predicates should include advisor",
                            queryPredicates.contains(advisor));
                    Assert.assertTrue("Domain predicates should include takesCourse",
                            queryPredicates.contains(takesCourse));
                    containsDomain = true;
                }
                else {
                    Assert.assertFalse("Range join should appear only once", containsRange);
                    // sp should be <unbound subject var, predicate var, original subject var> for range
                    Assert.assertFalse("Range SP subject should be unbound", sp.getSubjectVar().hasValue());
                    Assert.assertEquals("Range SP object should be the original subject var",
                            originalSP.getSubjectVar(), sp.getObjectVar());
                    // verify predicates
                    Assert.assertEquals("Unexpected number of range statements", 1, fsp.statements.size());
                    Assert.assertEquals("Unexpected set of range predicates", rangePredicates, queryPredicates);
                    Assert.assertTrue("Range predicates should include advisor",
                            queryPredicates.contains(advisor));
                    containsRange = true;
                }
            }
            else {
                Assert.fail("Expected nested Unions of Joins and StatementPatterns, found: " + currentNode.toString());
            }
        }

        // Checked separately so a failure names the branch that is missing.
        Assert.assertTrue("Rewritten query should contain the original StatementPattern", containsOriginal);
        Assert.assertTrue("Rewritten query should contain a domain inference join", containsDomain);
        Assert.assertTrue("Rewritten query should contain a range inference join", containsRange);
    }
}