blob: 413186bdefb5cc7ae4db845e2819dba3940b1cb2 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import Logging
import Foundation
/// Executes the transforms of a single `ProcessBundleDescriptor` for a worker.
///
/// At construction time each transform in the descriptor is resolved into a
/// `Step` — a `SerializableFn` plus its input/output element streams.
/// `process(instruction:responder:)` then runs all steps concurrently for a
/// given instruction and reports a single success or error response on the
/// supplied continuation.
struct BundleProcessor {
    let log: Logging.Logger

    /// One executable transform within the bundle.
    struct Step {
        let transformId: String
        let fn: SerializableFn
        let inputs: [AnyPCollectionStream]
        let outputs: [AnyPCollectionStream]
        let payload: Data
    }

    let steps: [Step]

    /// Resolves every transform in `descriptor` into a `Step`.
    ///
    /// - Parameters:
    ///   - id: Worker id; used to scope data-plane gRPC clients and the logger label.
    ///   - descriptor: Bundle descriptor received from the runner.
    ///   - collections: Known pcollections keyed by id; used to materialize streams.
    ///   - fns: User functions keyed by the transform's unique name.
    /// - Throws: If a runner source/sink payload or ParDo payload cannot be
    ///   decoded, a referenced coder cannot be resolved, or a data-plane client
    ///   cannot be created.
    init(id: String,
         descriptor: Org_Apache_Beam_Model_FnExecution_V1_ProcessBundleDescriptor,
         collections: [String: AnyPCollection],
         fns: [String: SerializableFn]) throws
    {
        log = Logging.Logger(label: "BundleProcessor(\(id) \(descriptor.id))")
        var temp: [Step] = []
        // Never mutated after construction — bind with `let` (was `var`).
        let coders = BundleCoderContainer(bundle: descriptor)
        var streams: [String: AnyPCollectionStream] = [:]
        // First make a single shared stream for every pcollection referenced by
        // this bundle, so steps that touch the same collection share one stream.
        for transform in descriptor.transforms.values {
            for pcollectionId in transform.inputs.values where streams[pcollectionId] == nil {
                streams[pcollectionId] = collections[pcollectionId]!.anyStream
            }
            for pcollectionId in transform.outputs.values where streams[pcollectionId] == nil {
                streams[pcollectionId] = collections[pcollectionId]!.anyStream
            }
        }
        for (transformId, transform) in descriptor.transforms {
            let urn = transform.spec.urn
            // Map the input and output streams in a stable (sorted-by-tag) order.
            let inputs = transform.inputs.sorted().map { streams[$0.1]! }
            let outputs = transform.outputs.sorted().map { streams[$0.1]! }
            switch urn {
            case "beam:runner:source:v1":
                let remotePort = try RemoteGrpcPort(serializedData: transform.spec.payload)
                let coder = try Coder.of(name: remotePort.coderID, in: coders)
                log.info("Source '\(transformId)','\(transform.uniqueName)' \(remotePort) \(coder)")
                try temp.append(Step(
                    transformId: transform.uniqueName == "" ? transformId : transform.uniqueName,
                    fn: Source(client: .client(for: ApiServiceDescriptor(proto: remotePort.apiServiceDescriptor), worker: id), coder: coder),
                    inputs: inputs,
                    outputs: outputs,
                    payload: Data()
                ))
            case "beam:runner:sink:v1":
                let remotePort = try RemoteGrpcPort(serializedData: transform.spec.payload)
                let coder = try Coder.of(name: remotePort.coderID, in: coders)
                log.info("Sink '\(transformId)','\(transform.uniqueName)' \(remotePort) \(coder)")
                try temp.append(Step(
                    transformId: transform.uniqueName == "" ? transformId : transform.uniqueName,
                    fn: Sink(client: .client(for: ApiServiceDescriptor(proto: remotePort.apiServiceDescriptor), worker: id), coder: coder),
                    inputs: inputs,
                    outputs: outputs,
                    payload: Data()
                ))
            case "beam:transform:pardo:v1":
                let pardoPayload = try Org_Apache_Beam_Model_Pipeline_V1_ParDoPayload(serializedData: transform.spec.payload)
                if let fn = fns[transform.uniqueName] {
                    temp.append(Step(transformId: transform.uniqueName,
                                     fn: fn,
                                     inputs: inputs,
                                     outputs: outputs,
                                     payload: pardoPayload.doFn.payload))
                } else {
                    log.warning("Unable to map \(transform.uniqueName) to a known SerializableFn. Will be skipped during processing.")
                }
            default:
                log.warning("Unable to map \(urn). Will be skipped during processing.")
            }
        }
        steps = temp
    }

    /// Runs every step of the bundle concurrently for `instruction`, then yields
    /// exactly one `InstructionResponse` on `responder`: a bare response on
    /// success, or one carrying the first error encountered.
    public func process(instruction: String, responder: AsyncStream<Org_Apache_Beam_Model_FnExecution_V1_InstructionResponse>.Continuation) async {
        _ = await withThrowingTaskGroup(of: (String, String).self) { group in
            log.info("Starting bundle processing for \(instruction)")
            var started = 0
            do {
                for step in steps {
                    log.info("Starting Task \(step.transformId)")
                    let context = SerializableFnBundleContext(instruction: instruction, transform: step.transformId, payload: step.payload, log: log)
                    group.addTask {
                        try await step.fn.process(context: context, inputs: step.inputs, outputs: step.outputs)
                    }
                    started += 1
                }
                var finished = 0
                // Loop bindings renamed (were `instruction`/`transform`) so the
                // `instruction` parameter is not shadowed inside the loop body.
                for try await (completedInstruction, completedTransform) in group {
                    finished += 1
                    log.info("Task Completed (\(completedInstruction),\(completedTransform)) \(finished) of \(started)")
                }
                log.info("All tasks completed for \(instruction)")
                responder.yield(.with {
                    $0.instructionID = instruction
                })
            } catch {
                // Cancel the still-running steps: since the error is caught here
                // the group closure returns normally, and the group would
                // otherwise implicitly await the remaining children even though
                // the bundle has already failed.
                group.cancelAll()
                responder.yield(.with {
                    $0.instructionID = instruction
                    $0.error = "\(error)"
                })
            }
        }
    }
}