| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /** |
| * TVM JS Wasm Runtime library. |
| */ |
| import { Pointer, PtrOffset, SizeOf, ArgTypeCode } from "./ctypes"; |
| import { Disposable } from "./types"; |
| import { Memory, CachedCallStack } from "./memory"; |
| import { assert, StringToUint8Array } from "./support"; |
| import { Environment } from "./environment"; |
| import { WebGPUContext } from "./webgpu"; |
| |
| import * as compact from "./compact"; |
| import * as ctypes from "./ctypes"; |
| |
/**
 * Type for PackedFunc in the TVMRuntime.
 *
 * A PackedFunc is a callable that is also {@link Disposable} and carries the
 * cell holding its function handle in `_tvmPackedCell`.
 */
export type PackedFunc =
  & ((...args: any) => any)
  & Disposable
  & { _tvmPackedCell: PackedFuncCell };
| |
/**
 * @internal
 * FFI Library wrapper, maintains most runtime states.
 *
 * Holds the wasm instance's exports and memory, turns non-zero return codes
 * of exported functions into exceptions, and pools {@link CachedCallStack}
 * objects so they can be reused across calls.
 */
class FFILibrary implements Disposable {
  /** Mirrors `memory.wasm32`. */
  wasm32: boolean;
  /** Wrapper around the wasm memory found in the exports or the imports. */
  memory: Memory;
  /** Exports of the wasm instance. */
  exports: Record<string, Function>;
  /** Optional WebGPU context attached to this library. */
  webGPUContext?: WebGPUContext;
  private wasmInstance: WebAssembly.Instance;
  /** Call stacks handed back through {@link recycleCallStack}. */
  private recycledCallStacks: Array<CachedCallStack> = [];

  /**
   * @param wasmInstance The instantiated wasm module.
   * @param imports The import object; consulted for `env.memory` when the
   *   instance does not export its memory.
   */
  constructor(
    wasmInstance: WebAssembly.Instance,
    imports: Record<string, any>
  ) {
    this.wasmInstance = wasmInstance;
    // Check the exports first: detectWasmMemory dereferences them, so the
    // assertion has to run before the memory lookup to be of any use.
    assert(
      this.wasmInstance.exports !== undefined,
      "Expect the library module contains exports"
    );
    this.exports = this.wasmInstance.exports as Record<string, Function>;
    this.memory = new Memory(this.detectWasmMemory(this.wasmInstance, imports));
    this.wasm32 = this.memory.wasm32;
    this.validateInstance();
  }

  /** Dispose every call stack currently held in the recycle pool. */
  dispose(): void {
    while (this.recycledCallStacks.length != 0) {
      (this.recycledCallStacks.pop() as Disposable).dispose();
    }
  }

  /** @returns The pointer size reported by the memory wrapper. */
  sizeofPtr(): number {
    return this.memory.sizeofPtr();
  }

  /**
   * Throw when an exported function reported failure.
   *
   * @param code Return code of the call; any non-zero value is an error.
   * @throws Error carrying the message obtained from `TVMGetLastError`.
   */
  checkCall(code: number): void {
    if (code != 0) {
      const msgPtr = (this.exports
        .TVMGetLastError as ctypes.FTVMGetLastError)();
      throw new Error("TVMError: " + this.memory.loadCString(msgPtr));
    }
  }

  /**
   * Take a call stack from the recycle pool, or create a new one when the
   * pool is empty.
   */
  getOrAllocCallStack(): CachedCallStack {
    if (this.recycledCallStacks.length != 0) {
      return this.recycledCallStacks.pop() as CachedCallStack;
    }
    return new CachedCallStack(
      this.memory,
      this.exports.TVMWasmAllocSpace as ctypes.FTVMWasmAllocSpace,
      this.exports.TVMWasmFreeSpace as ctypes.FTVMWasmFreeSpace
    );
  }

  /**
   * Reset a call stack and put it back into the pool for reuse.
   * @param callstack The stack obtained from {@link getOrAllocCallStack}.
   */
  recycleCallStack(callstack: CachedCallStack): void {
    callstack.reset();
    this.recycledCallStacks.push(callstack);
  }

  /** Verify that the exports required by this wrapper are present. */
  private validateInstance(): void {
    this.checkExports(["TVMWasmAllocSpace", "TVMWasmFreeSpace", "TVMFuncFree"]);
  }

  /**
   * @param funcNames Names that must be exported as functions.
   * @throws Error listing every name that is missing or not a function.
   */
  private checkExports(funcNames: Array<string>): void {
    const missList: Array<string> = [];
    for (const name of funcNames) {
      const f = this.exports[name];
      if (!(f instanceof Function)) {
        missList.push(name);
      }
    }
    if (missList.length != 0) {
      throw new Error("Cannot find " + missList.join(",") + " in exports");
    }
  }

  /**
   * Locate the wasm memory: `exports.memory` is preferred, then
   * `imports.env.memory`.
   *
   * @throws Error naming the available import/export keys when neither holds
   *   a `WebAssembly.Memory`.
   */
  private detectWasmMemory(
    instance: WebAssembly.Instance,
    imports: Record<string, any>
  ): WebAssembly.Memory {
    if (instance.exports.memory instanceof WebAssembly.Memory) {
      return instance.exports.memory;
    }
    if (imports.env && imports.env.memory instanceof WebAssembly.Memory) {
      return imports.env.memory;
    }

    // List the keys: concatenating the objects would only print
    // "[object Object]".
    throw new Error(
      "Cannot detect wasm memory from imports [" +
        Object.keys(imports).join(", ") +
        "] or exports [" +
        Object.keys(instance.exports).join(", ") +
        "]"
    );
  }
}
| |
/**
 * A number paired with a dtype string, used to pass a typed numeric
 * argument to PackedFunc calls.
 */
export class Scalar {
  /**
   * @param value The value.
   * @param dtype The data type of the scalar.
   */
  constructor(public value: number, public dtype: string) {}
}
| |
/**
 * Cell that owns the handle of a PackedFunc and frees it on dispose.
 */
class PackedFuncCell implements Disposable {
  /**
   * @param handle The function handle; set to 0 once it has been freed.
   * @param lib The library whose `TVMFuncFree` export releases the handle.
   */
  constructor(public handle: Pointer, private lib: FFILibrary) {}

  /** Free the handle through `TVMFuncFree`; does nothing when it is 0. */
  dispose(): void {
    if (this.handle == 0) {
      return;
    }
    const code = (this.lib.exports.TVMFuncFree as ctypes.FTVMFuncFree)(
      this.handle
    );
    this.lib.checkCall(code);
    this.handle = 0;
  }
}
| |
/** Device type code to the name printed by {@link DLContext.toString}. */
const DeviceEnumToStr: Record<number, string> = {
  1: "cpu",
  2: "gpu",
  4: "opencl",
  // Kept in sync with DeviceStrToEnum, which accepts "vulkan".
  7: "vulkan",
  8: "metal",
  15: "webgpu"
};

/** Accepted device names (including aliases) to device type code. */
const DeviceStrToEnum: Record<string, number> = {
  cpu: 1,
  gpu: 2,
  cuda: 2,
  cl: 4,
  opencl: 4,
  vulkan: 7,
  metal: 8,
  webgpu: 15
};

/**
 * Represent a runtime context where a NDArray can reside.
 */
export class DLContext {
  /** The device type code of the context. */
  deviceType: number;
  /** The device index. */
  deviceId: number;

  private lib: FFILibrary;

  /**
   * @param deviceType Device type code, or one of the names in DeviceStrToEnum.
   * @param deviceId The device index.
   * @param lib The library used by {@link sync} to reach the WebGPU context.
   * @throws Error when the name is unknown or deviceType is neither a string
   *   nor a number.
   */
  constructor(deviceType: number | string, deviceId: number, lib: FFILibrary) {
    if (typeof deviceType === "string") {
      // Own-property check: a plain index lookup would also find inherited
      // Object.prototype members such as "toString" or "constructor".
      if (!Object.prototype.hasOwnProperty.call(DeviceStrToEnum, deviceType)) {
        throw new Error("Cannot recogonize deviceType " + deviceType);
      }
      this.deviceType = DeviceStrToEnum[deviceType];
    } else if (typeof deviceType === "number") {
      this.deviceType = deviceType;
    } else {
      throw new Error(
        "Cannot take type " + typeof deviceType + " as deviceType"
      );
    }
    this.deviceId = deviceId;
    this.lib = lib;
  }

  /**
   * Synchronize the context.
   *
   * Only webgpu contexts wait (on the library's WebGPU context); for any
   * other device type the returned promise resolves without further work.
   */
  async sync(): Promise<void> {
    if (this.deviceType == DeviceStrToEnum.webgpu) {
      assert(this.lib.webGPUContext !== undefined);
      await this.lib.webGPUContext.sync();
    }
  }

  /** @returns `name(deviceId)`, e.g. `cpu(0)`. */
  toString(): string {
    return (
      DeviceEnumToStr[this.deviceType] + "(" + this.deviceId.toString() + ")"
    );
  }
}
/**
 * Type codes stored in the `code` field of a DLDataType.
 * Members count up from 0 in declaration order.
 */
export const enum DLDataTypeCode {
  Int = 0,
  UInt, // 1
  Float, // 2
  OpaqueHandle // 3
}
| |
/** Type code to the prefix used in dtype strings. */
const DLDataTypeCodeToStr: Record<number, string> = {
  0: "int",
  1: "uint",
  2: "float",
  3: "handle",
};

/**
 * Runtime data type of NDArray: a type code, a bit width and a lane count.
 */
export class DLDataType {
  /** The type code */
  code: number;
  /** Number of bits in the data type. */
  bits: number;
  /** Number of vector lanes. */
  lanes: number;

  constructor(code: number, bits: number, lanes: number) {
    this.code = code;
    this.bits = bits;
    this.lanes = lanes;
  }

  /**
   * @returns The dtype string: prefix followed by the bit width, with an
   *   `x<lanes>` suffix only when lanes differs from 1 (e.g. `int8x4`).
   */
  toString(): string {
    const scalarName = `${DLDataTypeCodeToStr[this.code]}${this.bits.toString()}`;
    return this.lanes != 1
      ? `${scalarName}x${this.lanes.toString()}`
      : scalarName;
  }

  /** @returns Bytes needed for one element, rounding the bit count up. */
  numStorageBytes(): number {
    const totalBits = this.bits * this.lanes;
    return (totalBits + 7) >> 3;
  }
}
| |
/**
 * n-dimensional array.
 *
 * The constructor reads ndim, shape, dtype, context and byte offset from the
 * tensor structure in wasm memory that `handle` refers to.
 */
export class NDArray implements Disposable {
  /** Internal array handle. */
  handle: Pointer;
  /** Number of dimensions. */
  ndim: number;
  /** Data type of the array. */
  dtype: string;
  /** Shape of the array. */
  shape: Array<number>;
  /** Context of the array. */
  context: DLContext;
  /** Whether it is a temporary view that can become invalid after the call. */
  private isView: boolean;
  private byteOffset: number;
  private dltensor: Pointer;
  private dataPtr: Pointer;
  private lib: FFILibrary;
  private dlDataType: DLDataType;

  /**
   * @param handle The array handle (for a view: the tensor pointer itself).
   * @param isView Whether the array is a view; views are never freed here.
   * @param lib The FFI library used for memory access and API calls.
   */
  constructor(handle: Pointer, isView: boolean, lib: FFILibrary) {
    this.handle = handle;
    this.isView = isView;
    this.lib = lib;

    if (this.isView) {
      this.dltensor = handle;
    } else {
      this.dltensor = this.getDLTensorFromArrayHandle(this.handle);
    }
    // Field offsets inside the tensor structure, derived from the pointer
    // size and the sizes declared in ctypes.
    const arrayOffsetData = 0;
    const arrayOffsetContext = arrayOffsetData + this.lib.sizeofPtr();
    const arrayOffsetDevType = arrayOffsetContext;
    const arrayOffsetDevId = arrayOffsetContext + SizeOf.I32;
    const arrayOffsetNdim = arrayOffsetContext + SizeOf.DLContext;
    const arrayOffsetDtype = arrayOffsetNdim + SizeOf.I32;
    const arrayOffsetDtypeCode = arrayOffsetDtype;
    const arrayOffsetDtypeBits = arrayOffsetDtype + SizeOf.U8;
    const arrayOffsetDtypeLanes = arrayOffsetDtypeBits + SizeOf.U8;
    const arrayOffsetShape = arrayOffsetDtype + SizeOf.DLDataType;
    const arrayOffsetStrides = arrayOffsetShape + this.lib.sizeofPtr();
    const arrayOffsetByteOffset = arrayOffsetStrides + this.lib.sizeofPtr();
    // dataPtr
    this.dataPtr = lib.memory.loadPointer(this.dltensor);
    // ndim
    this.ndim = lib.memory.loadI32(this.dltensor + arrayOffsetNdim);
    // shape
    const cshapePtr = lib.memory.loadPointer(this.dltensor + arrayOffsetShape);
    this.shape = [];
    for (let i = 0; i < this.ndim; ++i) {
      this.shape.push(lib.memory.loadI64(cshapePtr + i * SizeOf.I64));
    }
    // dtype
    const code = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeCode);
    const bits = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeBits);
    const lanes = lib.memory.loadU16(this.dltensor + arrayOffsetDtypeLanes);
    this.dlDataType = new DLDataType(code, bits, lanes);
    this.dtype = this.dlDataType.toString();

    // ctx
    const deviceType = lib.memory.loadI32(this.dltensor + arrayOffsetDevType);
    const deviceId = lib.memory.loadI32(this.dltensor + arrayOffsetDevId);
    this.context = new DLContext(deviceType, deviceId, lib);

    // byte_offset
    this.byteOffset = lib.memory.loadI64(this.dltensor + arrayOffsetByteOffset);
  }

  /** Free the array through `TVMArrayFree`; views and freed arrays are skipped. */
  dispose(): void {
    if (this.handle != 0 && !this.isView) {
      this.lib.checkCall(
        (this.lib.exports.TVMArrayFree as ctypes.FTVMArrayFree)(this.handle)
      );
      this.handle = 0;
    }
  }

  /**
   * Copy data from another NDArray or javascript array.
   * The number of elements must match.
   *
   * @param data The source data array.
   * @returns this
   * @throws Error when the element count differs from the shape, or when the
   *   dtype is not one of float32, float64, int32, int8, uint8.
   */
  copyFrom(data: NDArray | Array<number> | Float32Array): this {
    if (data instanceof NDArray) {
      this.lib.checkCall(
        (this.lib.exports.TVMArrayCopyFromTo as ctypes.FTVMArrayCopyFromTo)(
          data.handle,
          this.handle,
          0
        )
      );
      return this;
    }

    const size = this.numElements();
    if (data.length != size) {
      throw new Error(
        "data size and shape mismatch data.length=" +
          data.length +
          " vs " +
          size
      );
    }
    // Convert through the typed array matching the dtype so the raw bytes
    // have the element layout of the target array.
    let buffer: ArrayBuffer;
    switch (this.dtype) {
      case "float32":
        buffer = Float32Array.from(data).buffer;
        break;
      case "float64":
        buffer = Float64Array.from(data).buffer;
        break;
      case "int32":
        buffer = Int32Array.from(data).buffer;
        break;
      case "int8":
        buffer = Int8Array.from(data).buffer;
        break;
      case "uint8":
        buffer = Uint8Array.from(data).buffer;
        break;
      default:
        throw new Error("Unsupported data type " + this.dtype);
    }
    return this.copyFromRawBytes(new Uint8Array(buffer));
  }

  /**
   * Copy data from raw bytes.
   * @param data Uint8Array of bytes.
   * @returns this
   * @throws Error when `data.length` differs from the array's byte size.
   */
  copyFromRawBytes(data: Uint8Array): this {
    const nbytes = this.dlDataType.numStorageBytes() * this.numElements();
    if (nbytes != data.length) {
      throw new Error("Expect the data's length equals nbytes=" + nbytes);
    }

    const stack = this.lib.getOrAllocCallStack();
    try {
      // Stage the bytes in temporary wasm memory, then let the runtime copy
      // them into the array.
      const tempOffset = stack.allocRawBytes(nbytes);
      const tempPtr = stack.ptrFromOffset(tempOffset);
      this.lib.memory.storeRawBytes(tempPtr, data);
      this.lib.checkCall(
        (this.lib.exports
          .TVMArrayCopyFromBytes as ctypes.FTVMArrayCopyFromBytes)(
          this.handle,
          tempPtr,
          nbytes
        )
      );
    } finally {
      // Hand the stack back even when the copy fails.
      this.lib.recycleCallStack(stack);
    }
    return this;
  }

  /**
   * Return a copied Uint8Array of the raw bytes in the NDArray.
   * @returns The result array.
   * @throws Error when the array does not reside on a cpu context.
   */
  toRawBytes(): Uint8Array {
    if (this.context.deviceType != DeviceStrToEnum.cpu) {
      throw new Error(
        "Can only synchronize copy for CPU array, use copyFrom to copy into a CPU array instead."
      );
    }
    const nbytes = this.dlDataType.numStorageBytes() * this.numElements();

    const stack = this.lib.getOrAllocCallStack();
    try {
      const tempOffset = stack.allocRawBytes(nbytes);
      const tempPtr = stack.ptrFromOffset(tempOffset);
      this.lib.checkCall(
        (this.lib.exports.TVMArrayCopyToBytes as ctypes.FTVMArrayCopyToBytes)(
          this.handle,
          tempPtr,
          nbytes
        )
      );
      return this.lib.memory.loadRawBytes(tempPtr, nbytes);
    } finally {
      // Hand the stack back even when the copy fails.
      this.lib.recycleCallStack(stack);
    }
  }

  /**
   * Return a TypedArray copy of the NDArray, the specific type depends on
   * the dtype of the NDArray.
   * @returns The result array.
   * @throws Error when the dtype is not one of float32, float64, int32, int8,
   *   uint8.
   */
  toArray(): Float32Array | Float64Array | Int32Array | Int8Array | Uint8Array {
    switch (this.dtype) {
      case "float32":
        return new Float32Array(this.toRawBytes().buffer);
      case "float64":
        return new Float64Array(this.toRawBytes().buffer);
      case "int32":
        return new Int32Array(this.toRawBytes().buffer);
      case "int8":
        return new Int8Array(this.toRawBytes().buffer);
      case "uint8":
        return new Uint8Array(this.toRawBytes().buffer);
      default:
        throw new Error("Unsupported data type " + this.dtype);
    }
  }

  /** @returns The product of all shape entries (1 for an empty shape). */
  private numElements(): number {
    return this.shape.reduce((a, b) => a * b, 1);
  }

  private getDLTensorFromArrayHandle(handle: Pointer): Pointer {
    // Note: this depends on the NDArray C ABI.
    // keep this function in case of ABI change.
    return handle;
  }
}
| |
/**
 * Runtime Module.
 */
export class Module implements Disposable {
  /** Module handle; set to 0 once the module has been freed. */
  handle: Pointer;
  private lib: FFILibrary;
  private makePackedFunc: (ptr: Pointer) => PackedFunc;

  /**
   * @param handle The module handle.
   * @param lib The FFI library used for API calls.
   * @param makePackedFunc Factory wrapping a function handle as a PackedFunc.
   */
  constructor(
    handle: Pointer,
    lib: FFILibrary,
    makePackedFunc: (ptr: Pointer) => PackedFunc
  ) {
    this.handle = handle;
    this.lib = lib;
    this.makePackedFunc = makePackedFunc;
  }

  /** Free the module through `TVMModFree`; does nothing once the handle is 0. */
  dispose(): void {
    if (this.handle != 0) {
      this.lib.checkCall(
        (this.lib.exports.TVMModFree as ctypes.FTVMModFree)(this.handle)
      );
      this.handle = 0;
    }
  }

  /**
   * Get a function in the module.
   * @param name The name of the function.
   * @returns The result function.
   * @throws Error when the lookup fails or yields a null handle.
   */
  getFunction(name: string): PackedFunc {
    const stack = this.lib.getOrAllocCallStack();
    let handle: Pointer;
    try {
      const nameOffset = stack.allocRawBytes(name.length + 1);
      stack.storeRawBytes(nameOffset, StringToUint8Array(name));

      const outOffset = stack.allocPtrArray(1);
      const outPtr = stack.ptrFromOffset(outOffset);

      stack.commitToWasmMemory(outOffset);

      this.lib.checkCall(
        (this.lib.exports.TVMModGetFunction as ctypes.FTVMModGetFunction)(
          this.handle,
          stack.ptrFromOffset(nameOffset),
          // query_imports argument of the C API: also search imported modules.
          1,
          outPtr
        )
      );
      handle = this.lib.memory.loadPointer(outPtr);
    } finally {
      // Hand the stack back even when the lookup call fails.
      this.lib.recycleCallStack(stack);
    }
    if (handle == 0) {
      throw Error("Cannot find function " + name);
    }
    return this.makePackedFunc(handle);
  }

  /**
   * Import another module into the current runtime module.
   * @param mod The module to be imported.
   */
  importModule(mod: Module): void {
    this.lib.checkCall(
      (this.lib.exports.TVMModImport as ctypes.FTVMModImport)(
        this.handle,
        mod.handle
      )
    );
  }
}
| |
/**
 * Graph runtime.
 *
 * This is a thin wrapper of the underlying TVM module.
 * you can also directly call set_input, run, and get_output
 * of underlying module functions
 */
class GraphRuntime implements Disposable {
  module: Module;
  private packedSetInput: PackedFunc;
  private packedRun: PackedFunc;
  private packedGetOutput: PackedFunc;
  private packedLoadParams: PackedFunc;

  /**
   * Constructor
   * @param module The underlying module.
   */
  constructor(module: Module) {
    this.module = module;
    this.packedSetInput = module.getFunction("set_input");
    this.packedRun = module.getFunction("run");
    this.packedGetOutput = module.getFunction("get_output");
    this.packedLoadParams = module.getFunction("load_params");
  }

  /**
   * Dispose every function looked up in the constructor.
   * The underlying module itself is not disposed here.
   */
  dispose(): void {
    this.packedSetInput.dispose();
    this.packedRun.dispose();
    this.packedGetOutput.dispose();
    // Was previously left out, leaking the load_params function handle.
    this.packedLoadParams.dispose();
  }

  /**
   * Set input to the executor.
   *
   * @param key The input key; numbers are passed as an int32 scalar.
   * @param value The value to get set.
   */
  setInput(key: number | string, value: NDArray): void {
    if (typeof key == "number") {
      this.packedSetInput(new Scalar(key, "int32"), value);
    } else {
      this.packedSetInput(key, value);
    }
  }

  /**
   * Execute the underlying graph.
   */
  run(): void {
    this.packedRun();
  }

  /**
   * Get index-th output.
   * @param index The index number.
   * @param out The optional output storage parameters.
   * @returns The output array; `out` itself when it was provided.
   */
  getOutput(index: number, out: NDArray | undefined = undefined): NDArray {
    if (out !== undefined) {
      this.packedGetOutput(new Scalar(index, "int32"), out);
      return out;
    } else {
      return this.packedGetOutput(new Scalar(index, "int32"));
    }
  }

  /**
   * Load parameters from parameter binary.
   * @param paramBinary The parameter binary.
   */
  loadParams(paramBinary: Uint8Array): void {
    this.packedLoadParams(paramBinary);
  }

  /**
   * Benchmark stable execution of the graph(without data copy).
   * @param ctx The context to sync during each run.
   * @param number The number of times to compute the average.
   * @param repeat The number of times to repeat the run.
   * @returns One averaged time per repeat, in the unit of the performance
   *   timer's `now()`.
   */
  async benchmarkRuns(ctx: DLContext, number=10, repeat=4): Promise<number[]> {
    // Skip first run as it can involve GPU warmup and module loading time.
    const perf = compact.getPeformance();
    const results: number[] = [];
    this.run();
    await ctx.sync();
    for (let k = 0; k < repeat; ++k) {
      const tstart = perf.now();
      for (let i = 0; i < number; ++i) {
        this.run();
      }
      await ctx.sync();
      const tend = perf.now();
      results.push((tend - tstart) / number);
    }
    return results;
  }
}
| |
/**
 * Status code passed as the first argument of the async callback.
 */
const enum AyncCallbackCode {
  kReturn = 4,
  kException, // 5
}
| |
| /** |
| * TVM runtime instance. |
| */ |
| export class Instance implements Disposable { |
| memory: Memory; |
| exports: Record<string, Function>; |
| private lib: FFILibrary; |
| private env: Environment; |
| |
| /** |
| * Internal function(registered by the runtime) |
| */ |
| private wasmCreateLibraryModule?: PackedFunc & |
| ((getFunc: PackedFunc, getGlobal: PackedFunc) => PackedFunc); |
| |
/**
 * Constructor
 *
 * importObject can also be a {@link LibraryProvider} object,
 * a WASI object, or an object containing wasmLibraryProvider field.
 *
 * When no wasm instance is passed in, an {@link Environment} is built from
 * importObject and the module is instantiated with that environment's imports.
 *
 * @param wasmModule The input module or instance.
 * @param importObject The imports to initialize the wasmInstance if it is not provided.
 * @param wasmInstance Additional wasm instance argument for deferred construction.
 * @param env Directly specified environment module.
 *
 * @see Please use the async version {@link instantiate} when targeting browsers.
 */
constructor(
  wasmModule: WebAssembly.Module,
  importObject: Record<string, any> = {},
  wasmInstance?: WebAssembly.Instance,
  env?: Environment
) {
  if (!(wasmInstance instanceof WebAssembly.Instance)) {
    // Nothing was instantiated yet: build the environment and the instance.
    assert(env === undefined);
    env = new Environment(importObject);
    wasmInstance = new WebAssembly.Instance(wasmModule, env.imports);
  } else {
    // A ready-made instance must come together with its environment.
    assert(
      env instanceof Environment,
      "env must be provided when passing in instance"
    );
  }

  env.start(wasmInstance);
  this.env = env;

  const ffi = new FFILibrary(wasmInstance, env.imports);
  this.lib = ffi;
  this.memory = ffi.memory;
  this.exports = ffi.exports;

  this.registerEnvGlobalPackedFuncs();
}
| |
/** Dispose the underlying FFI library wrapper. */
dispose(): void {
  const { lib } = this;
  lib.dispose();
}
/**
 * Get system-wide library module in the wasm.
 * System lib is a global module that contains self register functions in startup.
 * @returns The system library module.
 */
systemLib(): Module {
  const getSysLib = this.getGlobalFunc("runtime.SystemLib");
  try {
    return getSysLib() as Module;
  } finally {
    // The lookup function is only needed for this one call; release it
    // even when the call throws.
    getSysLib.dispose();
  }
}
/**
 * List all the global function names registered in the runtime.
 * @returns The name list.
 */
listGlobalFuncNames(): Array<string> {
  const stack = this.lib.getOrAllocCallStack();
  try {
    // Two pointer-sized slots: the first receives the count, the second
    // the pointer to the array of C strings.
    const outSizeOffset = stack.allocPtrArray(2);

    const outSizePtr = stack.ptrFromOffset(outSizeOffset);
    const outArrayPtr = stack.ptrFromOffset(
      outSizeOffset + this.lib.sizeofPtr()
    );

    this.lib.checkCall(
      (this.exports.TVMFuncListGlobalNames as ctypes.FTVMFuncListGlobalNames)(
        outSizePtr,
        outArrayPtr
      )
    );

    const size = this.memory.loadI32(outSizePtr);
    const array = this.memory.loadPointer(outArrayPtr);
    const names: Array<string> = [];

    for (let i = 0; i < size; ++i) {
      names.push(
        this.memory.loadCString(
          this.memory.loadPointer(array + this.lib.sizeofPtr() * i)
        )
      );
    }
    return names;
  } finally {
    // Hand the stack back even when the call or a load throws.
    this.lib.recycleCallStack(stack);
  }
}
| |
/**
 * Register function to be global function in tvm runtime.
 * @param name The name of the function.
 * @param func function to be registered; plain functions are converted to
 *   a PackedFunc first.
 * @param override Whether overwrite function in existing registry.
 */
registerFunc(
  name: string,
  func: PackedFunc | Function,
  override = false
): void {
  const packedFunc = this.toPackedFunc(func);
  const ioverride = override ? 1 : 0;

  const stack = this.lib.getOrAllocCallStack();
  try {
    const nameOffset = stack.allocRawBytes(name.length + 1);
    stack.storeRawBytes(nameOffset, StringToUint8Array(name));
    stack.commitToWasmMemory();

    this.lib.checkCall(
      (this.lib.exports.TVMFuncRegisterGlobal as ctypes.FTVMFuncRegisterGlobal)(
        stack.ptrFromOffset(nameOffset),
        packedFunc._tvmPackedCell.handle,
        ioverride
      )
    );
  } finally {
    // The stack used to be taken from the pool and never handed back.
    this.lib.recycleCallStack(stack);
  }
}
| |
/**
 * Get global PackedFunc from the runtime.
 * @param name The name of the function.
 * @returns The result function.
 * @throws Error when the lookup fails or yields a null handle.
 */
getGlobalFunc(name: string): PackedFunc {
  const stack = this.lib.getOrAllocCallStack();
  let handle: Pointer;
  try {
    const nameOffset = stack.allocRawBytes(name.length + 1);
    stack.storeRawBytes(nameOffset, StringToUint8Array(name));
    const outOffset = stack.allocPtrArray(1);
    const outPtr = stack.ptrFromOffset(outOffset);

    stack.commitToWasmMemory(outOffset);

    this.lib.checkCall(
      (this.exports.TVMFuncGetGlobal as ctypes.FTVMFuncGetGlobal)(
        stack.ptrFromOffset(nameOffset),
        outPtr
      )
    );
    handle = this.memory.loadPointer(outPtr);
  } finally {
    // Hand the stack back even when the lookup call fails.
    this.lib.recycleCallStack(stack);
  }
  if (handle == 0) {
    throw Error("Cannot find global function " + name);
  }
  return this.makePackedFunc(handle);
}
| |
/**
 * Check if func is PackedFunc, i.e. a function that has its own
 * `_tvmPackedCell` property.
 *
 * @param func The input.
 * @returns The check result.
 */
isPackedFunc(func: unknown): boolean {
  if (typeof func != "function") {
    return false;
  }
  // eslint-disable-next-line no-prototype-builtins
  return func.hasOwnProperty("_tvmPackedCell");
}
| |
/**
 * Convert func to PackedFunc.
 * A value that already is a PackedFunc is returned unchanged.
 *
 * @param func Input function.
 * @returns The converted function.
 */
toPackedFunc(func: Function): PackedFunc {
  if (!this.isPackedFunc(func)) {
    const cfunc = this.wrapJSFuncAsPackedCFunc(func);
    return this.createPackedFuncFromCFunc(cfunc);
  }
  return func as PackedFunc;
}
| |
/**
 * Convert dtype to {@link DLDataType}
 *
 * Strings have the form `<prefix>[bits][x<lanes>]` with prefix one of
 * float, int, uint, handle. Bits default to 32 (64 for handle) and lanes
 * default to 1.
 *
 * @param dtype The input dtype string or DLDataType.
 * @returns The converted result.
 * @throws Error when the prefix is unknown, the lane count is not a number,
 *   or dtype is neither a string nor a DLDataType.
 */
toDLDataType(dtype: string | DLDataType): DLDataType {
  if (dtype instanceof DLDataType) return dtype;
  if (typeof dtype != "string") {
    throw new Error("Unknown dtype " + dtype);
  }

  let code: DLDataTypeCode;
  let bits = 32;
  let lanes = 1;
  // Remainder of the string once the type prefix has been stripped.
  let pattern: string;
  if (dtype.startsWith("float")) {
    pattern = dtype.substring(5);
    code = DLDataTypeCode.Float;
  } else if (dtype.startsWith("int")) {
    pattern = dtype.substring(3);
    code = DLDataTypeCode.Int;
  } else if (dtype.startsWith("uint")) {
    pattern = dtype.substring(4);
    code = DLDataTypeCode.UInt;
  } else if (dtype.startsWith("handle")) {
    // "handle" has six characters; stripping only five left a stray "e"
    // in front of the bit width, so an explicit width was never parsed.
    pattern = dtype.substring(6);
    code = DLDataTypeCode.OpaqueHandle;
    bits = 64;
  } else {
    throw new Error("Unknown dtype " + dtype);
  }

  const arr = pattern.split("x");
  // Only take the bit width when the text is exactly a decimal integer;
  // otherwise the default stays in place.
  const parsedBits = parseInt(arr[0]);
  if (parsedBits + "" == arr[0]) {
    bits = parsedBits;
  }
  if (arr.length >= 2) {
    lanes = parseInt(arr[1]);
    if (Number.isNaN(lanes)) {
      throw new Error("Unknown dtype " + dtype);
    }
  }
  return new DLDataType(code, bits, lanes);
}
| |
/**
 * Wrap a number and its dtype into a {@link Scalar} that can be passed to a
 * PackedFunc.
 * @param value The number value.
 * @param dtype The dtype string.
 * @returns The created scalar.
 */
scalar(value: number, dtype: string): Scalar {
  const typedValue = new Scalar(value, dtype);
  return typedValue;
}
| |
/**
 * Build a {@link DLContext} bound to this instance's library.
 * @param deviceType The device type (code or name).
 * @param deviceId The device index, 0 when omitted.
 * @returns The created context.
 */
context(deviceType: number | string, deviceId = 0): DLContext {
  const ctx = new DLContext(deviceType, deviceId, this.lib);
  return ctx;
}
| |
/**
 * Shorthand for a "cpu" {@link DLContext}.
 * @param deviceId The device index, 0 when omitted.
 * @returns The created context.
 */
cpu(deviceId = 0): DLContext {
  const deviceName = "cpu";
  return this.context(deviceName, deviceId);
}
| |
/**
 * Shorthand for a "webgpu" {@link DLContext}.
 * @param deviceId The device index, 0 when omitted.
 * @returns The created context.
 */
webgpu(deviceId = 0): DLContext {
  const deviceName = "webgpu";
  return this.context(deviceName, deviceId);
}
| |
/**
 * Create an empty {@link NDArray} with given shape and dtype.
 *
 * @param shape The shape of the array; a single number means a 1-D array.
 * @param dtype The data type of the array.
 * @param ctx The context of the ndarray.
 * @returns The created ndarray.
 */
empty(
  shape: Array<number> | number,
  dtype: string | DLDataType = "float32",
  ctx: DLContext = this.context("cpu", 0)
): NDArray {
  const dlType = this.toDLDataType(dtype);
  const dims = typeof shape == "number" ? [shape] : shape;

  const stack = this.lib.getOrAllocCallStack();
  try {
    // Lay the shape out as an int64 array for the allocation call.
    const shapeOffset = stack.allocRawBytes(dims.length * SizeOf.I64);
    for (let i = 0; i < dims.length; ++i) {
      stack.storeI64(shapeOffset + i * SizeOf.I64, dims[i]);
    }

    const outOffset = stack.allocPtrArray(1);
    const outPtr = stack.ptrFromOffset(outOffset);
    stack.commitToWasmMemory(outOffset);

    this.lib.checkCall(
      (this.exports.TVMArrayAlloc as ctypes.FTVMArrayAlloc)(
        stack.ptrFromOffset(shapeOffset),
        dims.length,
        dlType.code,
        dlType.bits,
        dlType.lanes,
        ctx.deviceType,
        ctx.deviceId,
        outPtr
      )
    );
    return new NDArray(this.memory.loadPointer(outPtr), false, this.lib);
  } finally {
    // Hand the stack back even when the allocation fails.
    this.lib.recycleCallStack(stack);
  }
}
| |
/**
 * Create a new graph runtime.
 *
 * @param graphJson The graph runtime json file.
 * @param lib The underlying library.
 * @param ctx The execution context of the graph.
 * @returns The graph runtime wrapping the created module.
 */
createGraphRuntime(
  graphJson: string,
  lib: Module,
  ctx: DLContext
): GraphRuntime {
  const fcreate = this.getGlobalFunc("tvm.graph_runtime.create");
  let module: Module;
  try {
    module = fcreate(
      graphJson,
      lib,
      this.scalar(ctx.deviceType, "int32"),
      this.scalar(ctx.deviceId, "int32")
    ) as Module;
  } finally {
    // The factory function is only needed for this call; release its
    // handle instead of leaking it.
    fcreate.dispose();
  }
  return new GraphRuntime(module);
}
| |
| |
/**
 * Register an async function to be global function in the server.
 *
 * The function is registered as `__async.<name>`. The registered variant
 * receives a callback as its last argument and invokes it with
 * `kReturn` and the resolved value, or with `kException` and the failure
 * message when the promise rejects.
 *
 * @param name The name of the function.
 * @param func function to be registered; must return a promise.
 * @param override Whether overwrite function in existing registry.
 *
 * @note The async function will only be used for serving remote calls in the rpc.
 */
registerAsyncServerFunc(
  name: string,
  func: Function,
  override = false
): void {
  const asyncVariant = (...args: Array<any>): void => {
    const fargs = args.slice(0, args.length - 1);
    const callback = args[args.length - 1] as PackedFunc;
    const promise: Promise<any> = func(...fargs);

    const onFulfilled = (rv: any): void => {
      callback(this.scalar(AyncCallbackCode.kReturn, "int32"), rv);
    };
    // Without this handler a rejection was silently dropped and the
    // callback never fired.
    // NOTE(review): the message is passed as a string — confirm that is what
    // the receiving side expects for kException.
    const onRejected = (reason: unknown): void => {
      callback(
        this.scalar(AyncCallbackCode.kException, "int32"),
        String(reason)
      );
    };
    promise.then(onFulfilled, onRejected);
  };
  this.registerFunc("__async." + name, asyncVariant, override);
}
| |
/**
 * Initialize webgpu in the runtime: create the WebGPU context, register the
 * functions that forward to it, and attach it to the library.
 * @param device The given GPU device.
 */
initWebGPU(device: GPUDevice): void {
  const gpuContext = new WebGPUContext(this.memory, device);

  const getDeviceAPI = (name: string) => gpuContext.getDeviceAPI(name);
  const createShader = (info: string, data: Uint8Array) =>
    gpuContext.createShader(info, data);
  const waitForTasks = async () => {
    await gpuContext.sync();
  };

  this.registerFunc("wasm.WebGPUDeviceAPI", getDeviceAPI);
  this.registerFunc("wasm.WebGPUCreateShader", createShader);
  this.registerAsyncServerFunc("wasm.WebGPUWaitForTasks", waitForTasks);

  this.lib.webGPUContext = gpuContext;
}
| |
/**
 * Register global packed functions needed by the backend to the env:
 * `wasm.TimeExecution` and `testing.asyncAddOne`, both as async server
 * functions.
 */
private registerEnvGlobalPackedFuncs(): void {
  // Register the timer function to enable the time_evaluator.
  const perf = compact.getPeformance();

  /**
   * Helper function to time the finvoke.
   *
   * finvoke is called with the number of runs to perform. For each repeat
   * the run count is raised until one measurement lasts at least
   * minRepeatMs. The result holds one float64 per repeat (duration in ms
   * divided by the run count and by 1000), returned as raw bytes.
   */
  const timeExecution = async (
    finvoke: PackedFunc,
    ctx: DLContext,
    nstep: number,
    repeat: number,
    minRepeatMs: number
  ): Promise<Uint8Array> => {
    // One untimed call before measuring.
    finvoke(this.scalar(1, "int32"));
    await ctx.sync();
    const result: number[] = [];
    let setupNumber: number = nstep;

    for (let i = 0; i < repeat; ++i) {
      let durationMs = 0.0;
      do {
        if (durationMs > 0.0) {
          // durationMs was measured over setupNumber runs, so the per-run
          // estimate and the growth factor must both be based on
          // setupNumber rather than on the initial nstep.
          setupNumber = Math.floor(
            Math.max(
              minRepeatMs / (durationMs / setupNumber) + 1,
              setupNumber * 1.618
            )
          );
        }
        const tstart: number = perf.now();
        finvoke(this.scalar(setupNumber, "int32"));
        await ctx.sync();
        const tend: number = perf.now();

        durationMs = tend - tstart;
      } while (durationMs < minRepeatMs);
      const speed = durationMs / setupNumber / 1000;
      result.push(speed);
    }
    const ret = new Float64Array(result.length);
    ret.set(result);
    return new Uint8Array(ret.buffer);
  };

  // Resolves to x + 1 after a 100 ms timeout.
  const addOne = async (x: number): Promise<number> => {
    await new Promise(resolve => setTimeout(resolve, 100));
    return x + 1;
  };

  this.registerAsyncServerFunc("wasm.TimeExecution", timeExecution);
  this.registerAsyncServerFunc("testing.asyncAddOne", addOne);
}
| |
/**
 * Store func in the environment's packed C function table and create a
 * PackedFunc that refers to that table slot.
 *
 * @param func The packed C function to wrap.
 * @returns The created PackedFunc.
 */
private createPackedFuncFromCFunc(
  func: ctypes.FTVMWasmPackedCFunc
): PackedFunc {
  // Reuse a freed slot when one is available, otherwise grow the table.
  let findex = this.env.packedCFuncTable.length;
  if (this.env.packedCFuncTableFreeId.length != 0) {
    findex = this.env.packedCFuncTableFreeId.pop() as number;
  } else {
    this.env.packedCFuncTable.push(undefined);
  }
  this.env.packedCFuncTable[findex] = func;

  const stack = this.lib.getOrAllocCallStack();
  try {
    const outOffset = stack.allocPtrArray(1);
    const outPtr = stack.ptrFromOffset(outOffset);
    this.lib.checkCall(
      (this.exports
        .TVMWasmFuncCreateFromCFunc as ctypes.FTVMWasmFuncCreateFromCFunc)(
        findex,
        outPtr
      )
    );
    return this.makePackedFunc(this.memory.loadPointer(outPtr));
  } finally {
    // Hand the stack back even when the creation call fails.
    this.lib.recycleCallStack(stack);
  }
}
| |
/**
 * Set packed function arguments into the location indicated by argsValue and argsCode.
 * Allocate new temporary space from the stack if necessary.
 *
 * Each argument i is written as a value at `argsValue + i * SizeOf.TVMValue`
 * and a type code at `argsCode + i * SizeOf.I32`.
 *
 * @param stack The call stack
 * @param args The input arguments.
 * @param argsValue The offset of argsValue.
 * @param argsCode The offset of argsCode.
 * @throws Error if an argument has a type that cannot be converted.
 */
setPackedArguments(
  stack: CachedCallStack,
  args: Array<any>,
  argsValue: PtrOffset,
  argsCode: PtrOffset
): void {
  for (let i = 0; i < args.length; ++i) {
    let val = args[i];
    const tp = typeof val;
    const valueOffset = argsValue + i * SizeOf.TVMValue;
    const codeOffset = argsCode + i * SizeOf.I32;
    if (val instanceof NDArray) {
      stack.storePtr(valueOffset, val.handle);
      stack.storeI32(codeOffset, ArgTypeCode.TVMNDArrayHandle);
    } else if (val instanceof Scalar) {
      // Scalars carry an explicit dtype that selects the storage format.
      if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) {
        stack.storeI64(valueOffset, val.value);
        stack.storeI32(codeOffset, ArgTypeCode.Int);
      } else if (val.dtype.startsWith("float")) {
        stack.storeF64(valueOffset, val.value);
        stack.storeI32(codeOffset, ArgTypeCode.Float);
      } else {
        assert(val.dtype == "handle", "Expect handle");
        stack.storePtr(valueOffset, val.value);
        stack.storeI32(codeOffset, ArgTypeCode.TVMOpaqueHandle);
      }
    } else if (val instanceof DLContext) {
      // A context is two consecutive int32 values: device type, then device id.
      stack.storeI32(valueOffset, val.deviceType);
      stack.storeI32(valueOffset + SizeOf.I32, val.deviceId);
      stack.storeI32(codeOffset, ArgTypeCode.TVMContext);
    } else if (tp == "number") {
      // Plain JS numbers are always passed as float64.
      stack.storeF64(valueOffset, val);
      stack.storeI32(codeOffset, ArgTypeCode.Float);
      // eslint-disable-next-line no-prototype-builtins
    } else if (tp == "function" && val.hasOwnProperty("_tvmPackedCell")) {
      // Already a PackedFunc: pass its existing handle.
      stack.storePtr(valueOffset, val._tvmPackedCell.handle);
      stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle);
    } else if (val === null || val == undefined) {
      stack.storePtr(valueOffset, 0);
      stack.storeI32(codeOffset, ArgTypeCode.Null);
    } else if (tp == "string") {
      stack.allocThenSetArgString(valueOffset, val);
      stack.storeI32(codeOffset, ArgTypeCode.TVMStr);
    } else if (val instanceof Uint8Array) {
      stack.allocThenSetArgBytes(valueOffset, val);
      stack.storeI32(codeOffset, ArgTypeCode.TVMBytes);
    } else if (val instanceof Function) {
      // Plain JS function: convert it to a PackedFunc and record it in the
      // stack's tempArgs list.
      val = this.toPackedFunc(val);
      stack.tempArgs.push(val);
      stack.storePtr(valueOffset, val._tvmPackedCell.handle);
      stack.storeI32(codeOffset, ArgTypeCode.TVMPackedFuncHandle);
    } else if (val instanceof Module) {
      stack.storePtr(valueOffset, val.handle);
      stack.storeI32(codeOffset, ArgTypeCode.TVMModuleHandle);
    } else {
      throw new Error("Unsupported argument type " + tp);
    }
  }
}
| |
/**
 * Adapt a JS function to the packed C function calling convention.
 *
 * The returned callback converts every incoming argument to a JS value,
 * calls `func` with them, and, when the result is neither undefined nor
 * null, passes it back through TVMCFuncSetReturn.
 *
 * @param func The JS function to wrap.
 * @returns A packed C function that always returns 0.
 */
private wrapJSFuncAsPackedCFunc(func: Function): ctypes.FTVMWasmPackedCFunc {
  const lib = this.lib;
  return (
    argValues: Pointer,
    argCodes: Pointer,
    nargs: number,
    ret: Pointer,
    // eslint-disable-next-line @typescript-eslint/no-unused-vars
    _handle: Pointer
  ): number => {
    const convertedArgs = [];
    for (let argIndex = 0; argIndex < nargs; ++argIndex) {
      const argValuePtr = argValues + argIndex * SizeOf.TVMValue;
      const argCodePtr = argCodes + argIndex * SizeOf.I32;

      // Object-like arguments go through TVMCbArgToReturn first
      // (presumably to turn them into return-style values — confirm).
      switch (lib.memory.loadI32(argCodePtr)) {
        case ArgTypeCode.TVMObjectHandle:
        case ArgTypeCode.TVMObjectRValueRefArg:
        case ArgTypeCode.TVMPackedFuncHandle:
        case ArgTypeCode.TVMModuleHandle:
          lib.checkCall(
            (lib.exports.TVMCbArgToReturn as ctypes.FTVMCbArgToReturn)(
              argValuePtr,
              argCodePtr
            )
          );
          break;
        default:
          break;
      }

      // Read the type code again after the optional conversion call.
      const effectiveCode = lib.memory.loadI32(argCodePtr);
      convertedArgs.push(this.retValueToJS(argValuePtr, effectiveCode, true));
    }

    const result = func(...convertedArgs);

    // Nothing to send back for undefined / null results.
    if (result === undefined || result === null) {
      return 0;
    }

    // Pack the single return value on a call stack and hand it to the runtime.
    const callStack = lib.getOrAllocCallStack();
    const retValueOffset = callStack.allocRawBytes(SizeOf.TVMValue);
    const retCodeOffset = callStack.allocRawBytes(SizeOf.I32);
    this.setPackedArguments(callStack, [result], retValueOffset, retCodeOffset);
    const retValuePtr = callStack.ptrFromOffset(retValueOffset);
    const retCodePtr = callStack.ptrFromOffset(retCodeOffset);
    callStack.commitToWasmMemory();
    lib.checkCall(
      (lib.exports.TVMCFuncSetReturn as ctypes.FTVMCFuncSetReturn)(
        ret,
        retValuePtr,
        retCodePtr,
        1
      )
    );
    lib.recycleCallStack(callStack);
    return 0;
  };
}
| |
/**
 * Build a callable JS PackedFunc around a function handle.
 *
 * The returned function packs its arguments onto a call stack, invokes
 * TVMFuncCall on the handle and converts the return value to JS.
 * It also carries a `dispose` method and the `_tvmPackedCell` that owns
 * the handle.
 *
 * @param handle The function handle.
 * @returns The callable PackedFunc.
 */
private makePackedFunc(handle: Pointer): PackedFunc {
  const cell = new PackedFuncCell(handle, this.lib);

  const invoke = (...args: any): any => {
    const callStack = this.lib.getOrAllocCallStack();

    // Reserve the argument value and type-code arrays, then fill them in.
    const argValuesOffset = callStack.allocRawBytes(
      SizeOf.TVMValue * args.length
    );
    const argCodesOffset = callStack.allocRawBytes(SizeOf.I32 * args.length);
    this.setPackedArguments(callStack, args, argValuesOffset, argCodesOffset);

    // Reserve the slots the callee writes its return value and code into.
    const retValueOffset = callStack.allocRawBytes(SizeOf.TVMValue);
    const retCodeOffset = callStack.allocRawBytes(SizeOf.I32);
    const retValuePtr = callStack.ptrFromOffset(retValueOffset);
    const retCodePtr = callStack.ptrFromOffset(retCodeOffset);

    // Commit only up to retValueOffset: the return slots do not need to be
    // committed.
    callStack.commitToWasmMemory(retValueOffset);

    const status = (this.exports.TVMFuncCall as ctypes.FTVMFuncCall)(
      handle,
      callStack.ptrFromOffset(argValuesOffset),
      callStack.ptrFromOffset(argCodesOffset),
      args.length,
      retValuePtr,
      retCodePtr
    );
    this.lib.checkCall(status);

    const retCode = this.memory.loadI32(retCodePtr);
    const jsValue = this.retValueToJS(retValuePtr, retCode, false);
    this.lib.recycleCallStack(callStack);
    return jsValue;
  };

  // JavaScript cannot overload the call operator, so the extra members are
  // attached directly onto the function object.
  const packed: any = invoke;
  packed.dispose = (): void => {
    cell.dispose();
  };
  packed._tvmPackedCell = cell;
  return packed as PackedFunc;
}
| |
/**
 * Convert a value stored in wasm memory into its JS representation.
 *
 * @param rvaluePtr Pointer to the stored value.
 * @param tcode The type code describing the stored value.
 * @param callbackArg Whether the value is an argument of a callback;
 *                    a TVMDLTensorHandle is only accepted in that case.
 * @returns The JS value; undefined for the Null type code.
 * @throws Error if the type code is not supported.
 */
private retValueToJS(rvaluePtr: Pointer, tcode: number, callbackArg: boolean): any {
  // Reads the stored value as a pointer; only called by pointer-valued codes.
  const loadHandle = (): Pointer => this.memory.loadPointer(rvaluePtr);

  if (tcode === ArgTypeCode.Int || tcode === ArgTypeCode.UInt) {
    return this.memory.loadI64(rvaluePtr);
  }
  if (tcode === ArgTypeCode.Float) {
    return this.memory.loadF64(rvaluePtr);
  }
  if (tcode === ArgTypeCode.TVMOpaqueHandle) {
    return loadHandle();
  }
  if (tcode === ArgTypeCode.TVMNDArrayHandle) {
    return new NDArray(loadHandle(), false, this.lib);
  }
  if (tcode === ArgTypeCode.TVMDLTensorHandle) {
    // Only valid when converting a callback argument.
    assert(callbackArg);
    return new NDArray(loadHandle(), true, this.lib);
  }
  if (tcode === ArgTypeCode.TVMPackedFuncHandle) {
    return this.makePackedFunc(loadHandle());
  }
  if (tcode === ArgTypeCode.TVMModuleHandle) {
    const makeFunc = (ptr: Pointer): PackedFunc => this.makePackedFunc(ptr);
    return new Module(loadHandle(), this.lib, makeFunc);
  }
  if (tcode === ArgTypeCode.Null) {
    return undefined;
  }
  if (tcode === ArgTypeCode.TVMContext) {
    // Two consecutive int32 values: device type followed by device id.
    const deviceType = this.memory.loadI32(rvaluePtr);
    const deviceId = this.memory.loadI32(rvaluePtr + SizeOf.I32);
    return this.context(deviceType, deviceId);
  }
  if (tcode === ArgTypeCode.TVMStr) {
    return this.memory.loadCString(loadHandle());
  }
  if (tcode === ArgTypeCode.TVMBytes) {
    return this.memory.loadTVMBytes(loadHandle());
  }
  throw new Error("Unsupported return type code=" + tcode);
}
| } |
| |
/**
 * Asynchronously instantiate a new {@link Instance}.
 *
 * importObject can also be a {@link LibraryProvider} object,
 * a WASI object, or an object containing wasmLibraryProvider field.
 * This makes it possible to reuse the syslib implementations from
 * Emscripten by passing its generated js Module as the imports.
 *
 * @param bufferSource The source to be compiled.
 * @param importObject The import objects.
 * @param logger The system logger.
 * @returns A promise that resolves to the created instance.
 */
export function instantiate(
  bufferSource: ArrayBuffer,
  importObject: Record<string, any> = {},
  logger: (msg: string) => void = console.log
): Promise<Instance> {
  // The environment is built synchronously; it supplies the wasm imports.
  const env = new Environment(importObject, logger);

  const toInstance = ({
    module,
    instance,
  }: WebAssembly.WebAssemblyInstantiatedSource): Instance =>
    new Instance(module, {}, instance, env);

  return WebAssembly.instantiate(bufferSource, env.imports).then(toInstance);
}