blob: a851034dd8db1c3e73f9ee73f509e035a94f3593 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.schemas.transforms;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.annotations.Experimental.Kind;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap;
import org.apache.commons.compress.utils.Lists;
/**
 * A transform for renaming fields inside an existing schema. Top level or nested fields can be
 * renamed. When renaming a nested field, the nested prefix does not need to be specified again when
 * specifying the new name.
 *
 * <p>Example use:
 *
 * <pre>{@code
 * PCollection<Event> events = readEvents();
 * PCollection<Row> renamedEvents =
 *     events.apply(RenameFields.<Event>create()
 *         .rename("userName", "userId")
 *         .rename("location.country", "countryCode"));
 * }</pre>
 */
@Experimental(Kind.SCHEMAS)
public class RenameFields {
  /** Create an instance of this transform. */
  public static <T> Inner<T> create() {
    return new Inner<>();
  }

  /** Describes a single rename rule: which field to rename and what to call it. */
  private static class RenamePair implements Serializable {
    // The FieldAccessDescriptor describing the field to rename. Must reference a singleton
    // field.
    private final FieldAccessDescriptor fieldAccessDescriptor;
    // The new name for the field.
    private final String newName;

    RenamePair(FieldAccessDescriptor fieldAccessDescriptor, String newName) {
      this.fieldAccessDescriptor = fieldAccessDescriptor;
      this.newName = newName;
    }

    /**
     * Resolves the field descriptor against {@code schema}.
     *
     * @return a new pair holding the resolved descriptor and the same new name
     * @throws IllegalArgumentException if the resolved descriptor does not reference a single
     *     field
     */
    RenamePair resolve(Schema schema) {
      FieldAccessDescriptor resolved = fieldAccessDescriptor.resolve(schema);
      if (!resolved.referencesSingleField()) {
        throw new IllegalArgumentException(resolved + " references multiple fields.");
      }
      return new RenamePair(resolved, newName);
    }
  }

  /**
   * Applies {@code renames} to the schemas nested inside {@code inputType}.
   *
   * <p>ROW, ARRAY and MAP types are rebuilt around their renamed contents; every other type is
   * returned unchanged. The nullability of the input type is carried over to the rebuilt type,
   * because the {@code FieldType.row/array/map} factories do not take it from the input.
   */
  private static FieldType renameFieldType(FieldType inputType, Collection<RenamePair> renames) {
    FieldType renamed;
    switch (inputType.getTypeName()) {
      case ROW:
        renamed = FieldType.row(renameSchema(inputType.getRowSchema(), renames));
        break;
      case ARRAY:
        renamed = FieldType.array(renameFieldType(inputType.getCollectionElementType(), renames));
        break;
      case MAP:
        renamed =
            FieldType.map(
                renameFieldType(inputType.getMapKeyType(), renames),
                renameFieldType(inputType.getMapValueType(), renames));
        break;
      default:
        return inputType;
    }
    return renamed.withNullable(inputType.getNullable());
  }

  /**
   * Applies the user-specified renames to the input schema.
   *
   * <p>Each rename either targets a field at this level of the schema, or a field nested below
   * one; nested renames are forwarded to the recursive call for the field that contains them.
   */
  private static Schema renameSchema(Schema inputSchema, Collection<RenamePair> renames) {
    // The mapping of renames to apply at this level of the schema.
    Map<Integer, String> topLevelRenames = Maps.newHashMap();
    // For nested schemas, collect all applicable renames here, keyed by the containing field id.
    Multimap<Integer, RenamePair> nestedRenames = ArrayListMultimap.create();
    for (RenamePair rename : renames) {
      FieldAccessDescriptor access = rename.fieldAccessDescriptor;
      if (!access.fieldIdsAccessed().isEmpty()) {
        // This references a field at this level of the schema.
        Integer fieldId = Iterables.getOnlyElement(access.fieldIdsAccessed());
        topLevelRenames.put(fieldId, rename.newName);
      } else {
        // This references a nested field. Strip one level off the descriptor and remember it
        // against the field that contains it.
        Map.Entry<Integer, FieldAccessDescriptor> nestedAccess =
            Iterables.getOnlyElement(access.nestedFieldsById().entrySet());
        nestedRenames.put(
            nestedAccess.getKey(), new RenamePair(nestedAccess.getValue(), rename.newName));
      }
    }

    Schema.Builder builder = Schema.builder();
    for (int i = 0; i < inputSchema.getFieldCount(); ++i) {
      Field field = inputSchema.getField(i);
      String newName = topLevelRenames.getOrDefault(i, field.getName());
      Collection<RenamePair> nestedFieldRenames = nestedRenames.asMap().get(i);
      // Only recurse when some rename targets a field nested below this one.
      FieldType newType =
          nestedFieldRenames == null
              ? field.getType()
              : renameFieldType(field.getType(), nestedFieldRenames);
      // Derive the new field from the original one so that attributes other than the name and
      // type are carried over.
      builder.addField(field.withName(newName).withType(newType));
    }
    return builder.build();
  }

  /**
   * Rebuilds {@code row} against {@code outputSchema}, recursively rebuilding nested values so
   * that nested rows also carry their renamed schemas. Field positions are unchanged by a rename,
   * so values are copied index by index.
   */
  private static Row renameRow(Row row, Schema outputSchema) {
    int fieldCount = outputSchema.getFieldCount();
    List<Object> values = new ArrayList<>(fieldCount);
    for (int i = 0; i < fieldCount; ++i) {
      values.add(renameValue(row.getValue(i), outputSchema.getField(i).getType()));
    }
    return Row.withSchema(outputSchema).attachValues(values).build();
  }

  /**
   * Rebuilds a single field value so that it matches {@code outputType}. Mirrors {@link
   * #renameFieldType}: only ROW, ARRAY and MAP values are rebuilt, everything else (including
   * null) is returned as is.
   */
  private static Object renameValue(Object value, FieldType outputType) {
    if (value == null) {
      return null;
    }
    switch (outputType.getTypeName()) {
      case ROW:
        return renameRow((Row) value, outputType.getRowSchema());
      case ARRAY:
        {
          FieldType elementType = outputType.getCollectionElementType();
          // ArrayList rather than an immutable list: elements may be null.
          List<Object> renamedElements = new ArrayList<>();
          for (Object element : (Iterable<?>) value) {
            renamedElements.add(renameValue(element, elementType));
          }
          return renamedElements;
        }
      case MAP:
        {
          FieldType keyType = outputType.getMapKeyType();
          FieldType valueType = outputType.getMapValueType();
          // LinkedHashMap keeps the input's iteration order and tolerates null values.
          Map<Object, Object> renamedEntries = new LinkedHashMap<>();
          for (Map.Entry<?, ?> entry : ((Map<?, ?>) value).entrySet()) {
            renamedEntries.put(
                renameValue(entry.getKey(), keyType), renameValue(entry.getValue(), valueType));
          }
          return renamedEntries;
        }
      default:
        return value;
    }
  }

  /** The class implementing the actual PTransform. */
  public static class Inner<T> extends PTransform<PCollection<T>, PCollection<Row>> {
    // The rename rules accumulated so far. Never mutated: rename() builds a new list.
    private final List<RenamePair> renames;

    private Inner() {
      renames = ImmutableList.of();
    }

    private Inner(List<RenamePair> renames) {
      this.renames = renames;
    }

    /**
     * Rename a specific field.
     *
     * @param field the name of the field to rename
     * @param newName the new name for the field
     * @return a new transform containing this rename in addition to the existing ones
     */
    public Inner<T> rename(String field, String newName) {
      return rename(FieldAccessDescriptor.withFieldNames(field), newName);
    }

    /**
     * Rename a specific field.
     *
     * @param field a descriptor of the field to rename
     * @param newName the new name for the field
     * @return a new transform containing this rename in addition to the existing ones
     */
    public Inner<T> rename(FieldAccessDescriptor field, String newName) {
      List<RenamePair> newList =
          ImmutableList.<RenamePair>builder()
              .addAll(renames)
              .add(new RenamePair(field, newName))
              .build();
      return new Inner<>(newList);
    }

    @Override
    public PCollection<Row> expand(PCollection<T> input) {
      Schema inputSchema = input.getSchema();
      List<RenamePair> pairs =
          renames.stream().map(r -> r.resolve(inputSchema)).collect(Collectors.toList());
      final Schema outputSchema = renameSchema(inputSchema, pairs);
      // A rename that accesses no field at the top level targets a nested field (same test as in
      // renameSchema). Only then do nested values have to be rebuilt.
      final boolean hasNestedRenames =
          pairs.stream().anyMatch(p -> p.fieldAccessDescriptor.fieldIdsAccessed().isEmpty());
      return input
          .apply(
              ParDo.of(
                  new DoFn<T, Row>() {
                    @ProcessElement
                    public void processElement(@Element Row row, OutputReceiver<Row> o) {
                      if (hasNestedRenames) {
                        // Nested rows must be rebuilt, otherwise they keep the old schema.
                        o.output(renameRow(row, outputSchema));
                      } else {
                        // Only top-level names changed: the existing values can be reused.
                        o.output(
                            Row.withSchema(outputSchema).attachValues(row.getValues()).build());
                      }
                    }
                  }))
          .setRowSchema(outputSchema);
    }
  }
}