blob: 04f987ae0a0cf4bdb58b22cab01da6a3abf25b18 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.client.solrj.io.eval;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Explanation.ExpressionType;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionValue;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
public abstract class RecursiveEvaluator implements StreamEvaluator, ValueWorker {
/** Explicit serialization version identifier for this class. */
protected static final long serialVersionUID = 1L;
/** Stream context most recently handed to {@code setStreamContext}; this class never assigns it anywhere else. */
protected StreamContext streamContext;
/** Randomly generated identifier of this evaluator instance; used as the id of the Explanation built in {@code toExplanation}. */
protected UUID nodeId = UUID.randomUUID();
/** Factory supplied to the constructor; reused to render this evaluator as an expression when building error messages. */
protected StreamFactory constructingFactory;
/** Evaluators created by the constructor from the expression's operands, kept in the order the operands were processed. */
protected List<StreamEvaluator> containedEvaluators = new ArrayList<StreamEvaluator>();
/**
 * Builds an evaluator from an expression without ignoring any named parameters.
 * Delegates to the three-argument constructor with an empty, mutable list of
 * ignored parameter names.
 *
 * @param expression the expression this evaluator is constructed from
 * @param factory the factory used to interpret the expression's operands
 * @throws IOException if the delegated constructor fails
 */
public RecursiveEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
  this(expression, factory, new ArrayList<String>());
}
/**
 * Converts a raw operand value into the representation evaluators operate on.
 *
 * <ul>
 *   <li>{@code null} and {@link VectorFunction} instances are returned unchanged.</li>
 *   <li>{@code Double} and {@code Float} NaN values become {@code null}.</li>
 *   <li>Any other {@link Number} becomes a {@link BigDecimal} built from its string form
 *       (an existing {@code BigDecimal} is returned as is).</li>
 *   <li>A {@link Collection} consisting solely of strings that all parse as doubles becomes a
 *       list of {@code Double}; any other collection is normalized element by element.</li>
 *   <li>An array of any component type is normalized element by element into a list.</li>
 *   <li>Everything else is returned unchanged.</li>
 * </ul>
 *
 * @param value the raw value, may be {@code null}
 * @return the normalized value, possibly {@code null}
 * @throws NumberFormatException if {@code value} is (or contains) an infinite {@code Double} or
 *         {@code Float}, since those have no {@code BigDecimal} representation
 */
protected Object normalizeInputType(Object value){
  if(null == value){
    return null;
  }
  else if(value instanceof VectorFunction){
    return value;
  }
  else if(value instanceof Double){
    if(Double.isNaN((Double)value)){
      return null;
    }
    return new BigDecimal(value.toString());
  }
  else if(value instanceof Float && ((Float)value).isNaN()){
    // Treat Float NaN like Double NaN; new BigDecimal("NaN") would throw NumberFormatException
    return null;
  }
  else if(value instanceof BigDecimal){
    return value;
  }
  else if(value instanceof Number){
    return new BigDecimal(value.toString());
  }
  else if(value instanceof Collection){
    Collection<?> collection = (Collection<?>)value;

    // First check whether this is a collection of Strings. If so, try to convert it
    // into a list of doubles; if any String is not numeric fall back to the generic path.
    try {
      List<Number> vector = new ArrayList<>();
      boolean allStrings = true;
      for(Object element : collection){
        if(!(element instanceof String)){
          allStrings = false;
          break;
        }
        vector.add(Double.valueOf((String)element));
      }
      if(allStrings){
        return vector;
      }
    } catch(NumberFormatException ignored){
      // At least one String is not a number: the Strings are kept as they are below.
    }

    return collection.stream().map(this::normalizeInputType).collect(Collectors.toList());
  }
  else if(value.getClass().isArray()){
    return arrayElements(value).map(this::normalizeInputType).collect(Collectors.toList());
  }
  else{
    // anything else can just be returned as is
    return value;
  }
}

/**
 * Exposes the elements of an array of any component type as a stream of (boxed) objects,
 * so that no array type is silently turned into an empty result.
 */
private static Stream<?> arrayElements(Object array){
  if(array instanceof double[]){
    return Arrays.stream((double[])array).boxed();
  }
  if(array instanceof int[]){
    return Arrays.stream((int[])array).boxed();
  }
  if(array instanceof long[]){
    return Arrays.stream((long[])array).boxed();
  }
  if(array instanceof Object[]){
    // Covers String[], Double[], nested arrays and every other reference array
    return Arrays.stream((Object[])array);
  }

  // Remaining primitive arrays (float[], short[], byte[], char[], boolean[]): box reflectively
  int length = java.lang.reflect.Array.getLength(array);
  List<Object> boxed = new ArrayList<>(length);
  for(int idx = 0; idx < length; ++idx){
    boxed.add(java.lang.reflect.Array.get(array, idx));
  }
  return boxed.stream();
}
/**
 * Converts an evaluator result into the representation handed back to callers.
 *
 * <ul>
 *   <li>{@code null} and {@link VectorFunction} instances are returned unchanged.</li>
 *   <li>{@link BigDecimal} becomes a {@code double}; {@code Long} and {@code Integer} become a
 *       {@code long}; {@code Double} is returned unchanged; any other {@link Number} becomes a
 *       {@code double}.</li>
 *   <li>A {@link List} is normalized element by element into a new list.</li>
 *   <li>A {@link Tuple} whose runtime class has no enclosing class is copied into a new
 *       {@code Tuple} with every field value normalized.</li>
 *   <li>Everything else is returned unchanged.</li>
 * </ul>
 *
 * @param value the result to normalize, may be {@code null}
 * @return the normalized result, possibly {@code null}
 */
protected Object normalizeOutputType(Object value) {
  if (value == null || value instanceof VectorFunction) {
    return value;
  }
  if (value instanceof BigDecimal) {
    return ((BigDecimal) value).doubleValue();
  }
  if (value instanceof Long || value instanceof Integer) {
    return ((Number) value).longValue();
  }
  if (value instanceof Double) {
    return value;
  }
  if (value instanceof Number) {
    return ((Number) value).doubleValue();
  }
  if (value instanceof List) {
    // Every element of the list is normalized on its own
    List<?> elements = (List<?>) value;
    return elements.stream().map(this::normalizeOutputType).collect(Collectors.toList());
  }

  // Only tuples whose class is not nested are rebuilt. Classes that extend Tuple as inner
  // classes (used to carry mathematical models inside a tuple) are passed through untouched.
  boolean plainTuple = value instanceof Tuple && value.getClass().getEnclosingClass() == null;
  if (!plainTuple) {
    // anything else can just be returned as is
    return value;
  }

  Tuple source = (Tuple) value;
  Tuple normalized = new Tuple();
  for (Object key : source.getFields().keySet()) {
    normalized.put(key, normalizeOutputType(source.get(key)));
  }
  return normalized;
}
/**
 * Builds an evaluator by turning each operand of the expression into a contained evaluator.
 *
 * <p>Nested expressions that the factory recognises as a {@code RecursiveEvaluator} or a
 * {@code SourceEvaluator} are constructed through the factory; any other nested expression is
 * treated as a field name. Non-empty literal values that the factory turns into {@code null},
 * a {@code Boolean} or a {@code Number} become raw values; those turned into a {@code String}
 * become field names.
 *
 * @param expression the expression this evaluator is constructed from
 * @param factory the factory used to interpret the operands; retained for later error reporting
 * @param ignoredNamedParameters names of named parameters the (currently disabled) operand
 *        validation would tolerate
 * @throws IOException if the factory fails while constructing a contained evaluator
 */
public RecursiveEvaluator(StreamExpression expression, StreamFactory factory, List<String> ignoredNamedParameters) throws IOException {
  this.constructingFactory = factory;

  // Operands are processed strictly in the order the factory returns them, because order matters
  List<StreamExpressionParameter> operands = factory.getOperandsOfType(expression, StreamExpressionParameter.class);
  for (StreamExpressionParameter operand : operands) {
    if (operand instanceof StreamExpression) {
      // possible evaluator
      StreamExpression nested = (StreamExpression) operand;
      boolean representsEvaluator = factory.doesRepresentTypes(nested, RecursiveEvaluator.class)
          || factory.doesRepresentTypes(nested, SourceEvaluator.class);
      if (representsEvaluator) {
        containedEvaluators.add(factory.constructEvaluator(nested));
      } else {
        // Will be treated as a field name
        containedEvaluators.add(new FieldValueEvaluator(nested.toString()));
      }
      continue;
    }

    if (!(operand instanceof StreamExpressionValue)) {
      continue;
    }

    StreamExpressionValue literal = (StreamExpressionValue) operand;
    if (literal.getValue().length() == 0) {
      continue;
    }

    // special case - a literal that evaluates to a number, boolean, or null is kept as a raw value
    Object primitive = factory.constructPrimitiveObject(literal.getValue());
    if (primitive == null || primitive instanceof Boolean || primitive instanceof Number) {
      containedEvaluators.add(new RawValueEvaluator(primitive));
    } else if (primitive instanceof String) {
      containedEvaluators.add(new FieldValueEvaluator((String) primitive));
    }
  }

  // NOTE: the two values below are only consumed by the operand validation that is commented out.
  Set<String> namedParameters = factory.getNamedOperands(expression).stream()
      .map(param -> param.getName())
      .collect(Collectors.toSet());
  long ignorableCount = ignoredNamedParameters.stream()
      .filter(namedParameters::contains)
      .count();

  /*
  if(0 != expression.getParameters().size() - containedEvaluators.size() - ignorableCount){
    if(namedParameters.isEmpty()){
      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - unknown operands found - expecting only StreamEvaluators or field names", expression));
    }
    else{
      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - unknown operands found - expecting only StreamEvaluators, field names, or named parameters [%s]", expression, namedParameters.stream().collect(Collectors.joining(","))));
    }
  }
  */
}
/**
 * Evaluates every contained evaluator against the tuple, hands the results to
 * {@code doWork} and normalizes what it returns.
 *
 * @param tuple the tuple to evaluate against
 * @return the normalized result of {@code doWork}
 * @throws IOException if evaluation fails; an {@link UncheckedIOException} raised along the way
 *         is unwrapped and its cause is thrown instead
 */
@Override
public Object evaluate(Tuple tuple) throws IOException {
  try {
    // doWork(Object ... values) must receive the results as an array of objects
    Object[] operands = recursivelyEvaluate(tuple).toArray();
    Object rawResult = doWork(operands);
    return normalizeOutputType(rawResult);
  } catch (UncheckedIOException wrapped) {
    throw wrapped.getCause();
  }
}
/**
 * Evaluates each contained evaluator against the tuple and collects the normalized results
 * in evaluator order.
 *
 * @param tuple the tuple to evaluate against
 * @return one normalized input value per contained evaluator
 * @throws IOException if a contained evaluator throws a {@code StreamEvaluatorException}
 *         (wrapped together with this evaluator's expression) or an {@code IOException}
 */
public List<Object> recursivelyEvaluate(Tuple tuple) throws IOException {
  List<Object> operandValues = new ArrayList<>();
  try {
    for (StreamEvaluator child : containedEvaluators) {
      Object rawValue = child.evaluate(tuple);
      operandValues.add(normalizeInputType(rawValue));
    }
  } catch (StreamEvaluatorException failure) {
    // Report the failing expression alongside the original message, keeping the cause
    String message = String.format(Locale.ROOT, "Failed to evaluate expression %s - %s", toExpression(constructingFactory), failure.getMessage());
    throw new IOException(message, failure);
  }
  return operandValues;
}
/**
 * Renders this evaluator as an expression: the function name the factory associates with this
 * class, followed by the expressions of all contained evaluators in order.
 *
 * @param factory the factory used to look up function names
 * @return the expression representing this evaluator
 * @throws IOException if the factory or a contained evaluator fails to produce its part
 */
@Override
public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
  String functionName = factory.getFunctionName(getClass());
  StreamExpression rendered = new StreamExpression(functionName);
  for (StreamEvaluator child : containedEvaluators) {
    StreamExpressionParameter childExpression = child.toExpression(factory);
    rendered.addParameter(childExpression);
  }
  return rendered;
}
/**
 * Describes this evaluator as an {@link Explanation} identified by {@code nodeId}, carrying
 * the evaluator expression type, the function name, the implementing class name and the
 * textual form of the expression.
 *
 * @param factory the factory used to look up the function name and render the expression
 * @return the explanation for this evaluator
 * @throws IOException if the function name lookup or expression rendering fails
 */
@Override
public Explanation toExplanation(StreamFactory factory) throws IOException {
  Explanation explanation = new Explanation(nodeId.toString());
  explanation = explanation.withExpressionType(ExpressionType.EVALUATOR);
  explanation = explanation.withFunctionName(factory.getFunctionName(getClass()));
  explanation = explanation.withImplementingClass(getClass().getName());
  explanation = explanation.withExpression(toExpression(factory).toString());
  return explanation;
}
/**
 * Stores the stream context on this evaluator and passes it on to every contained evaluator.
 *
 * @param context the stream context to use
 */
public void setStreamContext(StreamContext context) {
  this.streamContext = context;
  containedEvaluators.forEach(child -> child.setStreamContext(context));
}
/**
 * Returns the stream context last stored through {@code setStreamContext}.
 *
 * @return the current stream context, or {@code null} if none has been set
 */
public StreamContext getStreamContext() {
  return this.streamContext;
}
}