| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| /** |
| * TVM JS Wasm Runtime library. |
| */ |
| import { Pointer, PtrOffset, SizeOf, TypeIndex } from "./ctypes"; |
| import { Disposable } from "./types"; |
| import { Memory, CachedCallStack } from "./memory"; |
| import { assert, StringToUint8Array, LinearCongruentialGenerator } from "./support"; |
| import { Environment } from "./environment"; |
| import { AsyncifyHandler } from "./asyncify"; |
| import { FunctionInfo, WebGPUContext } from "./webgpu"; |
| import { |
| ArtifactCache, |
| ArtifactCacheTemplate, |
| ArtifactIndexedDBCache, |
| TensorShardEntry, |
| } from "./artifact_cache"; |
| import * as compact from "./compact"; |
| import * as ctypes from "./ctypes"; |
| |
| /** |
| * Type for PackedFunc in the TVMRuntime. |
| */ |
| export type PackedFunc = ((...args: any) => any) & |
| Disposable & { _tvmPackedCell: PackedFuncCell }; |
| |
| /** |
| * Type for AyncPackedFunc in TVMRuntime |
| * possibly may contain stack unwinding through Asynctify |
| */ |
| export type AsyncPackedFunc = ((...args: any) => Promise<any>) & |
| Disposable & { _tvmPackedCell: PackedFuncCell }; |
| |
| /** |
| * @internal |
| * FFI Library wrapper, maintains most runtime states. |
| */ |
| class FFILibrary implements Disposable { |
| wasm32: boolean; |
| memory: Memory; |
| exports: Record<string, Function>; |
| webGPUContext?: WebGPUContext; |
| private wasmInstance: WebAssembly.Instance; |
| private recycledCallStacks: Array<CachedCallStack> = []; |
| |
| constructor( |
| wasmInstance: WebAssembly.Instance, |
| imports: Record<string, any> |
| ) { |
| this.wasmInstance = wasmInstance; |
| this.memory = new Memory(this.detectWasmMemory(this.wasmInstance, imports)); |
| assert( |
| this.wasmInstance.exports !== undefined, |
| "Expect the library module contains exports" |
| ); |
| this.exports = this.wasmInstance.exports as Record<string, Function>; |
| this.wasm32 = this.memory.wasm32; |
| this.validateInstance(); |
| } |
| |
| dispose(): void { |
| while (this.recycledCallStacks.length != 0) { |
| (this.recycledCallStacks.pop() as Disposable).dispose(); |
| } |
| this.webGPUContext?.dispose(); |
| } |
| |
| sizeofPtr(): number { |
| return this.memory.sizeofPtr(); |
| } |
| |
| checkCall(code: number): void { |
| if (code != 0) { |
| const msgPtr = (this.exports |
| .TVMFFIWasmGetLastError as ctypes.FTVMFFIWasmGetLastError)(); |
| throw new Error(this.memory.loadCString(msgPtr)); |
| } |
| } |
| |
| getOrAllocCallStack(): CachedCallStack { |
| if (this.recycledCallStacks.length != 0) { |
| return this.recycledCallStacks.pop() as CachedCallStack; |
| } |
| return new CachedCallStack( |
| this.memory, |
| this.exports.TVMWasmAllocSpace as ctypes.FTVMWasmAllocSpace, |
| this.exports.TVMWasmFreeSpace as ctypes.FTVMWasmFreeSpace |
| ); |
| } |
| |
| recycleCallStack(callstack: CachedCallStack): void { |
| callstack.reset(); |
| this.recycledCallStacks.push(callstack); |
| } |
| |
| private validateInstance(): void { |
| this.checkExports(["TVMWasmAllocSpace", "TVMWasmFreeSpace"]); |
| } |
| |
| private checkExports(funcNames: Array<string>): void { |
| const missList = []; |
| for (const name of funcNames) { |
| const f = this.exports[name]; |
| if (!(f instanceof Function)) { |
| missList.push(name); |
| } |
| } |
| if (missList.length != 0) { |
| throw new Error("Cannot find " + missList + " in exports"); |
| } |
| } |
| |
| private detectWasmMemory( |
| instance: WebAssembly.Instance, |
| imports: Record<string, any> |
| ): WebAssembly.Memory { |
| if (instance.exports.memory instanceof WebAssembly.Memory) { |
| return instance.exports.memory; |
| } |
| if (imports.env && imports.env.memory instanceof WebAssembly.Memory) { |
| return imports.env.memory; |
| } |
| |
| throw new Error( |
| "Cannt detect wasm memory from imports " + |
| imports + |
| " or exports" + |
| instance.exports |
| ); |
| } |
| } |
| |
| /** |
| * @internal |
| * Manages extra runtime context for the runtime. |
| */ |
| class RuntimeContext implements Disposable { |
| functionListGlobalNamesFunctor: PackedFunc; |
| moduleGetFunction: PackedFunc; |
| moduleImport: PackedFunc; |
| tensorEmpty: PackedFunc; |
| tensorCopyFromTo: PackedFunc; |
| tensorCopyFromJSBytes: PackedFunc; |
| tensorCopyToJSBytes: PackedFunc; |
| arrayGetItem: PackedFunc; |
| arrayGetSize: PackedFunc; |
| arrayMake: PackedFunc; |
| arrayConcat: PackedFunc; |
| getSysLib: PackedFunc; |
| tensorCacheGet: PackedFunc; |
| tensorCacheUpdate: PackedFunc; |
| tensorCacheRemove: PackedFunc; |
| tensorCacheClear: PackedFunc; |
| arrayDecodeStorage: PackedFunc; |
| paramModuleFromCache: PackedFunc; |
| paramModuleFromCacheByName: PackedFunc; |
| makeShapeTuple: PackedFunc; |
| tensorCreateView: PackedFunc; |
| sampleTopPFromLogits: PackedFunc; |
| sampleTopPFromProb: PackedFunc; |
| applyRepetitionPenalty: PackedFunc; |
| applyPresenceAndFrequencyPenalty: PackedFunc; |
| applySoftmaxWithTemperature: PackedFunc; |
| concatEmbeddings: PackedFunc | undefined; |
| bool: PackedFunc; |
| private autoDisposeScope: Array<Array<Disposable | undefined>> = []; |
| |
| constructor( |
| getGlobalFunc: (name: string) => PackedFunc |
| ) { |
| this.functionListGlobalNamesFunctor = getGlobalFunc( |
| "ffi.FunctionListGlobalNamesFunctor" |
| ); |
| this.moduleGetFunction = getGlobalFunc("ffi.ModuleGetFunction"); |
| this.moduleImport = getGlobalFunc("ffi.ModuleImportModule"); |
| this.tensorEmpty = getGlobalFunc("runtime.TVMTensorAllocWithScope"); |
| this.tensorCopyFromTo = getGlobalFunc("runtime.TVMTensorCopyFromTo"); |
| this.tensorCopyFromJSBytes = getGlobalFunc("tvmjs.runtime.NDTensorCopyFromBytes"); |
| this.tensorCopyToJSBytes = getGlobalFunc("tvmjs.runtime.TensorCopyToBytes"); |
| this.arrayGetItem = getGlobalFunc("ffi.ArrayGetItem"); |
| this.arrayGetSize = getGlobalFunc("ffi.ArraySize"); |
| this.arrayMake = getGlobalFunc("ffi.Array"); |
| this.arrayConcat = getGlobalFunc("tvmjs.runtime.ArrayConcat"); |
| this.getSysLib = getGlobalFunc("ffi.SystemLib"); |
| this.tensorCacheGet = getGlobalFunc("vm.builtin.tensor_cache.get"); |
| this.tensorCacheRemove = getGlobalFunc("vm.builtin.tensor_cache.remove"); |
| this.tensorCacheUpdate = getGlobalFunc("vm.builtin.tensor_cache.update"); |
| this.tensorCacheClear = getGlobalFunc("vm.builtin.tensor_cache.clear"); |
| this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage"); |
| this.paramModuleFromCache = getGlobalFunc("vm.builtin.param_module_from_cache"); |
| this.paramModuleFromCacheByName = getGlobalFunc("vm.builtin.param_module_from_cache_by_name"); |
| this.makeShapeTuple = getGlobalFunc("ffi.Shape"); |
| this.tensorCreateView = getGlobalFunc("runtime.TVMTensorCreateView"); |
| this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits"); |
| this.sampleTopPFromProb = getGlobalFunc("vm.builtin.sample_top_p_from_prob"); |
| this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty"); |
| this.applyPresenceAndFrequencyPenalty = getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty"); |
| this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature"); |
| this.concatEmbeddings = getGlobalFunc("tvmjs.runtime.ConcatEmbeddings"); |
| } |
| |
| dispose(): void { |
| // call array cache clear to clear all cached items |
| this.tensorCacheClear.dispose(); |
| this.arrayGetItem.dispose(); |
| this.arrayGetSize.dispose(); |
| this.arrayMake.dispose(); |
| this.arrayConcat.dispose(); |
| this.tensorCacheGet.dispose(); |
| this.tensorCacheRemove.dispose(); |
| this.tensorCacheUpdate.dispose(); |
| this.tensorCacheClear.dispose(); |
| this.arrayDecodeStorage.dispose(); |
| this.paramModuleFromCache.dispose(); |
| this.paramModuleFromCacheByName.dispose(); |
| this.makeShapeTuple.dispose(); |
| this.tensorCreateView.dispose(); |
| this.sampleTopPFromLogits.dispose(); |
| this.applyRepetitionPenalty.dispose(); |
| this.applyPresenceAndFrequencyPenalty.dispose(); |
| this.applySoftmaxWithTemperature.dispose(); |
| this.concatEmbeddings?.dispose(); |
| } |
| |
| beginScope(): void { |
| this.autoDisposeScope.push([]); |
| } |
| |
| endScope(): void { |
| if (this.autoDisposeScope.length === 0) { |
| throw Error("tvm.endScope called when the stack is empty."); |
| } |
| // automatically dispose all the tracked values in the current scope. |
| const currScope = this.autoDisposeScope.pop() as Array<Disposable | undefined>; |
| for (let i = 0; i < currScope.length; ++i) { |
| const val = currScope[i]; |
| if (val !== undefined) { |
| val.dispose(); |
| } |
| } |
| } |
| |
| /** |
| * Track object for dispose in current scope. |
| * |
| * @param obj The object to be tracked. |
| * @returns the same object. |
| * @note This function only needs to be called for raw system C API values. |
| * The return value of PackedFunc will be automatically tracked. |
| */ |
| attachToCurrentScope<T extends Disposable>(obj: T): T { |
| if (this.autoDisposeScope.length === 0) { |
| throw Error("Must call beginScope to use functions that returns TVM objects"); |
| } |
| const currScope = this.autoDisposeScope[this.autoDisposeScope.length - 1]; |
| currScope.push(obj); |
| return obj; |
| } |
| |
| moveToParentScope<T extends Disposable>(obj: T): T { |
| this.detachFromCurrentScope(obj); |
| if (this.autoDisposeScope.length < 2) { |
| throw Error("moveToParentScope: Parent scope do not exist"); |
| } |
| const parentScope = this.autoDisposeScope[this.autoDisposeScope.length - 2]; |
| parentScope.push(obj); |
| return obj; |
| } |
| |
| detachFromCurrentScope<T extends Disposable>(obj: T): T { |
| const currScope = this.autoDisposeScope[this.autoDisposeScope.length - 1]; |
| let occurrence = 0; |
| for (let i = 0; i < currScope.length; ++i) { |
| if (currScope[i] === obj) { |
| occurrence += 1; |
| currScope[i] = undefined; |
| } |
| } |
| if (occurrence === 0) { |
| throw Error("Cannot find obj in the current auto conversion pool"); |
| } |
| if (occurrence > 1) { |
| throw Error("Value attached to scope multiple times"); |
| } |
| return obj; |
| } |
| } |
| |
| /** |
| * A typed scalar constant used to represent a typed number |
| * argument to PackedFunc calls. |
| */ |
| export class Scalar { |
| /** The value. */ |
| value: number; |
| /** The data type of the scalar. */ |
| dtype: string; |
| |
| constructor(value: number, dtype: string) { |
| this.value = value; |
| this.dtype = dtype; |
| } |
| } |
| |
| const DeviceEnumToStr: Record<number, string> = { |
| 1: "cpu", |
| 2: "cuda", |
| 4: "opencl", |
| 8: "metal", |
| 15: "webgpu" |
| }; |
| |
| const DeviceStrToEnum: Record<string, number> = { |
| cpu: 1, |
| cuda: 2, |
| cl: 4, |
| opencl: 4, |
| vulkan: 7, |
| metal: 8, |
| webgpu: 15 |
| }; |
| |
| /** |
| * Represent a runtime context where a Tensor can reside. |
| */ |
| export class DLDevice { |
| /** The device type code of the device. */ |
| deviceType: number; |
| /** The device index. */ |
| deviceId: number; |
| |
| private lib: FFILibrary; |
| |
| constructor(deviceType: number | string, deviceId: number, lib: FFILibrary) { |
| const tp = typeof deviceType; |
| if (tp === "string") { |
| this.deviceType = DeviceStrToEnum[deviceType]; |
| if (this.deviceType === undefined) { |
| throw new Error("Cannot recogonize deviceType " + deviceType); |
| } |
| } else if (tp === "number") { |
| this.deviceType = deviceType as number; |
| } else { |
| throw new Error("Cannot take type " + tp + " as deviceType"); |
| } |
| this.deviceId = deviceId; |
| this.lib = lib; |
| } |
| |
| /** |
| * Synchronize the device |
| */ |
| async sync(): Promise<void> { |
| if (this.deviceType === DeviceStrToEnum.webgpu) { |
| assert(this.lib.webGPUContext !== undefined); |
| await this.lib.webGPUContext.sync(); |
| } |
| } |
| |
| toString(): string { |
| return ( |
| DeviceEnumToStr[this.deviceType] + ":" + this.deviceId.toString() |
| ); |
| } |
| } |
| /** |
| * The data type code in DLDataType |
| */ |
| export enum DLDataTypeCode { |
| Int = 0, |
| UInt = 1, |
| Float = 2, |
| OpaqueHandle = 3 |
| } |
| |
| const DLDataTypeCodeToStr: Record<number, string> = { |
| 0: "int", |
| 1: "uint", |
| 2: "float", |
| 3: "handle", |
| }; |
| |
| /** |
| * Runtime data type of Tensor. |
| */ |
| export class DLDataType { |
| /** The type code */ |
| code: number; |
| /** Number of bits in the data type. */ |
| bits: number; |
| /** Number of vector lanes. */ |
| lanes: number; |
| |
| constructor(code: number, bits: number, lanes: number) { |
| this.code = code; |
| this.bits = bits; |
| this.lanes = lanes; |
| } |
| |
| toString(): string { |
| const ret = DLDataTypeCodeToStr[this.code] + this.bits.toString(); |
| if (this.lanes != 1) { |
| return ret + "x" + this.lanes.toString(); |
| } else { |
| return ret; |
| } |
| } |
| |
| numStorageBytes(): number { |
| return (this.bits * this.lanes + 7) >> 3; |
| } |
| } |
| |
| /** |
| * Generic object base |
| */ |
| export class TVMObject implements Disposable { |
| protected handle: Pointer; |
| protected lib: FFILibrary; |
| protected ctx: RuntimeContext; |
| |
| constructor( |
| handle: Pointer, |
| lib: FFILibrary, |
| ctx: RuntimeContext |
| ) { |
| this.handle = handle; |
| this.lib = lib; |
| this.ctx = ctx; |
| } |
| |
| dispose(): void { |
| if (this.handle != 0) { |
| this.lib.checkCall( |
| (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(this.handle) |
| ); |
| this.handle = 0; |
| } |
| } |
| |
| /** |
| * Get handle of module, check it is not null. |
| * |
| * @param requireNotNull require handle is not null. |
| * @returns The handle. |
| */ |
| getHandle(requireNotNull = true): Pointer { |
| if (requireNotNull && this.handle === 0) { |
| throw Error("Object has already been disposed"); |
| } |
| return this.handle; |
| } |
| |
| /** get the type index of the object */ |
| typeIndex(): number { |
| if (this.handle === 0) { |
| throw Error("The current Object has already been disposed"); |
| } |
| return this.lib.memory.loadObjectTypeIndex(this.handle); |
| } |
| |
| /** get the type key of the object */ |
| typeKey(): string { |
| const type_index = this.typeIndex(); |
| const typeInfoPtr = (this.lib.exports.TVMFFIGetTypeInfo as ctypes.FTVMFFIGetTypeInfo)( |
| type_index |
| ); |
| return this.lib.memory.loadTypeInfoTypeKey(typeInfoPtr); |
| } |
| } |
| |
| /** |
| * Cell holds the PackedFunc object. |
| */ |
| class PackedFuncCell extends TVMObject { |
| constructor(handle: Pointer, lib: FFILibrary, ctx: RuntimeContext) { |
| super(handle, lib, ctx); |
| } |
| } |
| |
| /** |
| * Tensor( n-dimnesional array). |
| */ |
| |
| export class Tensor extends TVMObject { |
| /** Number of dimensions. */ |
| ndim: number; |
| /** Data type of the array. */ |
| dtype: string; |
| /** Shape of the array. */ |
| shape: Array<number>; |
| /** Device of the array. */ |
| device: DLDevice; |
| /** Whether it is a temporary view that can become invalid after the call. */ |
| isView: boolean; |
| private byteOffset: number; |
| private dltensor: Pointer; |
| private dataPtr: Pointer; |
| private dlDataType: DLDataType; |
| |
| constructor(handle: Pointer, lib: FFILibrary, ctx: RuntimeContext, isView: boolean) { |
| // if the array is a view, we need to create a new object with a null handle |
| // so dispose won't trigger memory free |
| const objectHandle = isView ? 0 : handle; |
| super(objectHandle, lib, ctx); |
| this.isView = isView; |
| if (this.isView) { |
| this.dltensor = handle; |
| } else { |
| this.dltensor = this.getDLTensorFromArrayHandle(this.handle); |
| } |
| // constant offsets. |
| const arrayOffsetData = 0; |
| const arrayOffsetContext = arrayOffsetData + this.lib.sizeofPtr(); |
| const arrayOffsetDevType = arrayOffsetContext; |
| const arrayOffsetDevId = arrayOffsetContext + SizeOf.I32; |
| const arrayOffsetNdim = arrayOffsetContext + SizeOf.DLDevice; |
| const arrayOffsetDtype = arrayOffsetNdim + SizeOf.I32; |
| const arrayOffsetDtypeCode = arrayOffsetDtype; |
| const arrayOffsetDtypeBits = arrayOffsetDtype + SizeOf.U8; |
| const arrayOffsetDtypeLanes = arrayOffsetDtypeBits + SizeOf.U8; |
| const arrayOffsetShape = arrayOffsetDtype + SizeOf.DLDataType; |
| const arrayOffsetStrides = arrayOffsetShape + this.lib.sizeofPtr(); |
| const arrayOffsetByteOffset = arrayOffsetStrides + this.lib.sizeofPtr(); |
| // dataPtr |
| this.dataPtr = lib.memory.loadPointer(this.dltensor); |
| // ndim |
| this.ndim = lib.memory.loadI32(this.dltensor + arrayOffsetNdim); |
| // shape |
| const cshapePtr = lib.memory.loadPointer(this.dltensor + arrayOffsetShape); |
| this.shape = []; |
| for (let i = 0; i < this.ndim; ++i) { |
| this.shape.push(lib.memory.loadI64(cshapePtr + i * SizeOf.I64)); |
| } |
| // dtype |
| const code = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeCode); |
| const bits = lib.memory.loadU8(this.dltensor + arrayOffsetDtypeBits); |
| const lanes = lib.memory.loadU16(this.dltensor + arrayOffsetDtypeLanes); |
| this.dlDataType = new DLDataType(code, bits, lanes); |
| this.dtype = this.dlDataType.toString(); |
| |
| // device |
| const deviceType = lib.memory.loadI32(this.dltensor + arrayOffsetDevType); |
| const deviceId = lib.memory.loadI32(this.dltensor + arrayOffsetDevId); |
| this.device = new DLDevice(deviceType, deviceId, lib); |
| |
| // byte_offset |
| this.byteOffset = lib.memory.loadI64(this.dltensor + arrayOffsetByteOffset); |
| } |
| |
| /** |
| * Create a view of the array. |
| * @param shape The shape of the view. |
| * @param dtype The data type of the new array. |
| * @returns The new sliced ndarray. |
| */ |
| view(shape: Array<number>, dtype?: string): Tensor { |
| const shapeArray = shape.map((value) => new Scalar(value, "int")); |
| if (dtype === undefined) { |
| dtype = this.dtype; |
| } |
| return this.ctx.tensorCreateView( |
| this, |
| this.ctx.makeShapeTuple(...shapeArray), |
| this.dtype, |
| /*relative_byte_offset=*/ new Scalar(0, "int"), |
| ); |
| } |
| /** |
| * Get dataPtr of NDarray |
| * |
| * @returns The handle. |
| */ |
| getDataPtr(): Pointer { |
| if (this.handle === 0) { |
| throw Error("Tensor has already been disposed"); |
| } |
| return this.dataPtr; |
| } |
| |
| /** |
| * Copy data from another Tensor or javascript array. |
| * The number of elements must match. |
| * |
| * @param data The source data array. |
| * @returns this |
| */ |
| copyFrom( |
| data: Tensor | Array<number> | Float32Array | Float64Array | |
| Int32Array | Int8Array | Uint8Array | Uint8ClampedArray |
| ): this { |
| if (data instanceof Tensor) { |
| this.ctx.tensorCopyFromTo(data, this); |
| return this; |
| } else { |
| const size = this.shape.reduce((a, b) => { |
| return a * b; |
| }, 1); |
| if (data.length != size) { |
| throw new Error( |
| "data size and shape mismatch data.length" + |
| data.length + |
| " vs " + |
| size |
| ); |
| } |
| let buffer: ArrayBuffer; |
| if (this.dtype === "float32") { |
| buffer = Float32Array.from(data).buffer; |
| } else if (this.dtype === "float64") { |
| buffer = Float64Array.from(data).buffer; |
| } else if (this.dtype === "int32") { |
| buffer = Int32Array.from(data).buffer; |
| } else if (this.dtype === "int8") { |
| buffer = Int8Array.from(data).buffer; |
| } else if (this.dtype === "uint8") { |
| buffer = Uint8Array.from(data).buffer; |
| } else if (this.dtype === "uint32") { |
| buffer = Uint32Array.from(data).buffer; |
| } else { |
| throw new Error("Unsupported data type " + this.dtype); |
| } |
| return this.copyFromRawBytes(new Uint8Array(buffer)); |
| } |
| } |
| /** |
| * Copy data from raw bytes. |
| * @param data Uint8Array of bytes. |
| * @returns this |
| */ |
| copyFromRawBytes(data: Uint8Array): this { |
| // short cut for gpu copy |
| if (this.device.deviceType === DeviceStrToEnum.webgpu) { |
| this.lib.webGPUContext?.copyRawBytesToBuffer(data, this.getDataPtr(), 0, data.length); |
| return this; |
| } |
| // CPU copy |
| const size = this.shape.reduce((a, b) => { |
| return a * b; |
| }, 1); |
| const nbytes = this.dlDataType.numStorageBytes() * size; |
| if (nbytes != data.length) { |
| throw new Error("Expect the data's length equals nbytes=" + nbytes); |
| } |
| this.ctx.tensorCopyFromJSBytes(this, data); |
| return this; |
| } |
| /** |
| * Return a copied Uint8Array of the raw bytes in the Tensor. |
| * @returns The result array. |
| */ |
| toRawBytes(): Uint8Array { |
| if (this.device.deviceType != DeviceStrToEnum.cpu) { |
| throw new Error("Can only sync copy CPU array, use cpu_arr.copyfrom(gpu_arr) then sync instead."); |
| } |
| return this.ctx.tensorCopyToJSBytes(this) as Uint8Array; |
| } |
| |
| /** |
| * Return a TypedArray copy of the Tensor, the specific type depends on |
| * the dtype of the Tensor. |
| * @returns The result array. |
| */ |
| toArray(): Float32Array | Float64Array | Int32Array | Int8Array | Uint8Array { |
| const stype = this.dtype; |
| if (stype === "float32") { |
| return new Float32Array(this.toRawBytes().buffer); |
| } else if (stype === "float64") { |
| return new Float64Array(this.toRawBytes().buffer); |
| } else if (stype === "int32") { |
| return new Int32Array(this.toRawBytes().buffer); |
| } else if (stype === "int8") { |
| return new Int8Array(this.toRawBytes().buffer); |
| } else if (stype === "uint8") { |
| return new Uint8Array(this.toRawBytes().buffer); |
| } else { |
| throw new Error("Unsupported data type " + this.dtype); |
| } |
| } |
| |
| private getDLTensorFromArrayHandle(handle: Pointer): Pointer { |
| return handle + SizeOf.ObjectHeader; |
| } |
| } |
| |
| |
| /** |
| * Runtime Module. |
| */ |
| export class Module extends TVMObject { |
| constructor( |
| handle: Pointer, |
| lib: FFILibrary, |
| ctx: RuntimeContext, |
| ) { |
| super(handle, lib, ctx); |
| } |
| /** |
| * Get a function in the module. |
| * @param name The name of the function. |
| * @param queryImports Whether to also query imports |
| * @returns The result function. |
| */ |
| getFunction(name: string, queryImports = true): PackedFunc { |
| return this.ctx.moduleGetFunction(this, name, queryImports) as PackedFunc; |
| } |
| |
| /** |
| * Import another module into the current runtime module. |
| * @param mod The module to be imported. |
| */ |
| importModule(mod: Module): void { |
| this.ctx.moduleImport(this, mod); |
| } |
| } |
| |
| |
| /** Objectconstructor */ |
| type FObjectConstructor = (handle: Pointer, lib: FFILibrary, ctx: RuntimeContext) => TVMObject; |
| |
| /** All possible object types. */ |
| type TVMObjectBase = TVMObject | PackedFunc; |
| |
| /** Runtime array object. */ |
| export class TVMArray extends TVMObject { |
| constructor( |
| handle: Pointer, |
| lib: FFILibrary, |
| ctx: RuntimeContext |
| ) { |
| super(handle, lib, ctx); |
| } |
| |
| /** |
| * @returns the size of the array. |
| */ |
| size(): number { |
| return this.ctx.arrayGetSize(this) as number; |
| } |
| /** |
| * Get index-th element of the array |
| * @param index the array index. |
| * @returns The element. |
| */ |
| get(index: number): TVMObjectBase { |
| return this.ctx.arrayGetItem(this, new Scalar(index, "int32")) as TVMObjectBase; |
| } |
| } |
| |
| export enum VMAllocatorKind { |
| NAIVE_ALLOCATOR = 1, |
| POOLED_ALLOCATOR = 2, |
| } |
| |
| /** |
| * VirtualMachine Executor. |
| * |
| * This is a thin wrapper of the underlying TVM module. |
| * you can also directly call set_input, run, and get_output |
| * of underlying module functions |
| */ |
| export class VirtualMachine implements Disposable { |
| private mod: Module; |
| /** |
| * Constructor |
| * @param mod The underlying module, need to be detached. |
| * @param device The main device ro run VM on. |
| */ |
| constructor(mod: Module, device: DLDevice) { |
| this.mod = mod; |
| this.mod.getFunction("vm_initialization")( |
| new Scalar(device.deviceType, "int"), |
| new Scalar(device.deviceId, "int"), |
| new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int"), |
| // explicitly specify host device type |
| new Scalar(DeviceStrToEnum.cpu, "int"), |
| new Scalar(0, "int"), |
| new Scalar(VMAllocatorKind.POOLED_ALLOCATOR, "int"), |
| ); |
| } |
| |
| dispose(): void { |
| this.mod.dispose(); |
| } |
| /** |
| * Get a function in the VM module. |
| * @param name The name of the function. |
| * @returns The result function. |
| */ |
| getFunction(name: string): PackedFunc { |
| return this.mod.getFunction(name); |
| } |
| |
| /** |
| * Get the internal module. |
| */ |
| getInternalModule(): Module { |
| return this.mod; |
| } |
| } |
| |
| /** Code used as the first argument of the async callback. */ |
| enum AsyncCallbackCode { |
| kReturn = 4, |
| kException = 5, |
| } |
| |
| export interface InitProgressReport { |
| progress: number; |
| timeElapsed: number; |
| text: string; |
| } |
| |
| export type InitProgressCallback = (report: InitProgressReport) => void; |
| |
| /** |
| * TVM runtime instance. |
| * |
| * All objects(Tensor, Module, PackedFunc) returned by TVM runtim function call |
| * and PackedFunc instance are tracked through a scope mechanism that will get |
| * auto-released when we call EndScope. |
| * |
| * This is necessarily to be able to release the underlying WASM and WebGPU memory that |
| * are not tracked through JS native garbage collection mechanism. |
| * |
| * This does mean that we have to get familar with the following functions: |
| * - {@link beginScope} |
| * - {@link endScope} |
| * - {@link withNewScope} |
| * - {@link attachToCurrentScope} |
| * - {@link detachFromCurrentScope} |
| */ |
| export class Instance implements Disposable { |
| memory: Memory; |
| exports: Record<string, Function>; |
| cacheMetadata: Record<string, any> = {}; |
| private lib: FFILibrary; |
| private env: Environment; |
| private objFactory: Map<number, FObjectConstructor>; |
| private ctx: RuntimeContext; |
| private asyncifyHandler: AsyncifyHandler; |
| private initProgressCallback: Array<InitProgressCallback> = []; |
| private rng: LinearCongruentialGenerator; |
| private deviceLostIsError = true; // whether device.lost is due to actual error or dispose() |
| |
| /** |
| * Internal function(registered by the runtime) |
| */ |
| private wasmCreateLibraryModule?: PackedFunc & |
| ((getFunc: PackedFunc, getGlobal: PackedFunc) => PackedFunc); |
| |
| /** |
| * Constructor |
| * |
| * importObject can also be a {@link LibraryProvider} object, |
| * a WASI object, or an object containing wasmLibraryProvider field. |
| * |
| * @param wasmModule The input module or instance. |
| * @param importObject The imports to initialize the wasmInstance if it is not provided. |
| * @param wasmInstance Additional wasm instance argument for deferred construction. |
| * @param env Directly specified environment module. |
| * |
| * @see Please use the async version {@link instantiate} when targeting browsers. |
| */ |
| constructor( |
| wasmModule: WebAssembly.Module, |
| importObject: Record<string, any> = {}, |
| wasmInstance?: WebAssembly.Instance, |
| env?: Environment |
| ) { |
| if (wasmInstance instanceof WebAssembly.Instance) { |
| assert( |
| env instanceof Environment, |
| "env must be provided when passing in instance" |
| ); |
| } else { |
| assert(env === undefined); |
| env = new Environment(importObject); |
| wasmInstance = new WebAssembly.Instance(wasmModule, env.imports); |
| } |
| env.start(wasmInstance); |
| this.env = env; |
| this.lib = new FFILibrary(wasmInstance, env.imports); |
| this.memory = this.lib.memory; |
| this.exports = this.lib.exports; |
| this.asyncifyHandler = new AsyncifyHandler(this.exports, this.memory.memory); |
| this.objFactory = new Map<number, ObjectConstructor>(); |
| this.ctx = new RuntimeContext( |
| (name: string) => { |
| const autoAttachToScope = false; |
| // runtime context function do not auto-release. |
| return this.getGlobalFuncInternal(name, autoAttachToScope); |
| } |
| ); |
| this.registerEnvGlobalPackedFuncs(); |
| this.registerObjectFactoryFuncs(); |
| this.rng = new LinearCongruentialGenerator(); |
| } |
| |
| /** |
| * Benchmark stable execution of the run function. |
| * |
| * @params run The run function |
| * @params dev The device to sync during each run. |
| * @number The number of times to compute the average. |
| * @repeat The number of times to repeat the run. |
| */ |
| async benchmark(run: () => void, dev: DLDevice, number = 10, repeat = 1): Promise<number[]> { |
| // Skip first run as it can involve GPU warmup and module loading time. |
| const perf = compact.getPerformance(); |
| const results = []; |
| |
| // run with new scope |
| this.withNewScope(run); |
| await dev.sync(); |
| |
| for (let k = 0; k < repeat; ++k) { |
| const tstart = perf.now(); |
| for (let i = 0; i < number; ++i) { |
| this.withNewScope(run); |
| } |
| await dev.sync(); |
| const tend = perf.now(); |
| results.push((tend - tstart) / number); |
| } |
| return results; |
| } |
| |
| /** |
| * Check whether we enabled asyncify mode |
| * @returns The asynctify mode toggle |
| */ |
| asyncifyEnabled(): boolean { |
| return this.asyncifyHandler.enabled(); |
| } |
| |
| dispose(): void { |
| this.deviceLostIsError = false; // prevent dispose to trigger device.lost error |
| // order matters |
| // ctx release goes back into lib. |
| this.ctx.dispose(); |
| this.lib.dispose(); |
| // Cannot set deviceLostIsError back to true here because GPUDevice.destroy() is asynchronous. |
| } |
| |
| /** |
| * Obtain the runtime information in readable format. |
| */ |
| runtimeStatsText(): string { |
| if (this.lib.webGPUContext !== undefined) { |
| return this.lib.webGPUContext.runtimeStatsText(); |
| } else { |
| return ""; |
| } |
| } |
| |
| /** |
| * Begin a new scope for tracking object disposal. |
| */ |
| beginScope(): void { |
| this.ctx.beginScope(); |
| } |
| |
| /** |
| * End a scope and release all created TVM objects |
| * under the current scope. |
| * |
| * Exception: one can call {@link moveToParentScope} to move |
| * a value to parent scope. |
| */ |
| endScope(): void { |
| this.ctx.endScope(); |
| } |
| |
| /** |
| * Perform action under a new scope. |
| * |
| * @param action The action function. |
| * @returns The result value. |
| * |
| * @note For action to return a valid value, |
| * we will need to call {@link moveToParentScope} |
| * for the objects that are created in the scope. |
| */ |
| withNewScope<T>(action: () => T): T { |
| this.beginScope(); |
| const val = action(); |
| this.endScope(); |
| return val; |
| } |
| |
| /** |
| * Attach a detached obj to the auto-release pool of the current scope. |
| * |
| * @param obj The input obj. |
| * @note Normally user do not need to call this function explicitly, as |
| * all library call return values are explicitly attached to |
| * the current scope. You only need to do so when you call |
| * {@link detachFromCurrentScope} to create a detached object. |
| */ |
| attachToCurrentScope<T extends Disposable>(obj: T): T { |
| return this.ctx.attachToCurrentScope(obj); |
| } |
| |
| /** |
| * Move obj's attachment to the parent scope. |
| * |
| * This function is useful to make sure objects are still |
| * alive when exit the current scope. |
| * |
| * @param obj The object to be moved. |
| * @returns The input obj. |
| */ |
| moveToParentScope<T extends Disposable>(obj: T): T { |
| return this.ctx.moveToParentScope(obj); |
| } |
| |
| /** |
| * Detach the object from the current scope |
| * so it won't be released via auto-release during endscope. |
| * |
| * User needs to either explicitly call obj.dispose(), or |
| * {@link attachToCurrentScope} to re-attach to the current scope. |
| * |
| * This function can be used to return values to the parent scope. |
| * @param obj The object. |
| */ |
| detachFromCurrentScope<T extends Disposable>(obj: T): T { |
| return this.ctx.detachFromCurrentScope(obj); |
| } |
| |
| /** |
| * Get system-wide library module in the wasm. |
| * System lib is a global module that contains self register functions in startup. |
| * @returns The system library module. |
| */ |
| systemLib(): Module { |
| return this.ctx.getSysLib() as Module; |
| } |
| /** |
| * List all the global function names registered in the runtime. |
| * @returns The name list. |
| */ |
| listGlobalFuncNames(): Array<string> { |
| return this.withNewScope(() => { |
| const functor = this.ctx.functionListGlobalNamesFunctor(); |
| const numNames = functor(new Scalar(-1, "int")) as number; |
| const names = new Array<string>(numNames); |
| for (let i = 0; i < numNames; i++) { |
| names[i] = functor(new Scalar(i, "int")) as string; |
| } |
| return names; |
| }); |
| } |
| /** |
| * Register function to be global function in tvm runtime. |
| * @param name The name of the function. |
| * @param f function to be registered. |
| * @param override Whether overwrite function in existing registry. |
| */ |
| registerFunc( |
| name: string, |
| func: PackedFunc | Function, |
| override = false |
| ): void { |
| this.withNewScope(() => { |
| const autoAttachToScope = true; |
| // packed func can be released once it is registered |
| const packedFunc = this.toPackedFuncInternal(func, autoAttachToScope); |
| const ioverride = override ? 1 : 0; |
| |
| const stack = this.lib.getOrAllocCallStack(); |
| const nameOffset = stack.allocByteArrayForString(name); |
| stack.commitToWasmMemory(); |
| this.lib.checkCall( |
| (this.lib.exports.TVMFFIFunctionSetGlobal as ctypes.FTVMFFIFunctionSetGlobal)( |
| stack.ptrFromOffset(nameOffset), |
| packedFunc._tvmPackedCell.getHandle(), |
| ioverride |
| ) |
| ); |
| this.lib.recycleCallStack(stack); |
| }); |
| } |
| |
| /** |
| * Get global PackedFunc from the runtime. |
| * @param name The name of the function. |
| * @param autoAttachToScope Whether to track it via autoDispose |
| * @returns The result function. |
| */ |
| getGlobalFunc(name: string): PackedFunc { |
| return this.getGlobalFuncInternal(name, true); |
| } |
| |
| private getGlobalFuncInternal(name: string, autoAttachToScope = true): PackedFunc { |
| const stack = this.lib.getOrAllocCallStack(); |
| const nameOffset = stack.allocByteArrayForString(name); |
| const outOffset = stack.allocPtrArray(1); |
| const outPtr = stack.ptrFromOffset(outOffset); |
| |
| stack.commitToWasmMemory(outOffset); |
| |
| this.lib.checkCall( |
| (this.exports.TVMFFIFunctionGetGlobal as ctypes.FTVMFFIFunctionGetGlobal)( |
| stack.ptrFromOffset(nameOffset), |
| outPtr |
| ) |
| ); |
| const handle = this.memory.loadPointer(outPtr); |
| this.lib.recycleCallStack(stack); |
| if (handle === 0) { |
| throw Error("Cannot find global function " + name); |
| } |
| const ret = this.makePackedFunc(handle); |
| if (autoAttachToScope) this.ctx.attachToCurrentScope(ret); |
| return ret; |
| } |
| |
| /** |
| * Check if func is PackedFunc. |
| * |
| * @param func The input. |
| * @returns The check result. |
| */ |
| isPackedFunc(func: unknown): boolean { |
| // eslint-disable-next-line no-prototype-builtins |
| return typeof func === "function" && func.hasOwnProperty("_tvmPackedCell"); |
| } |
| |
| /** |
| * Convert func to PackedFunc |
| * |
| * @param func Input function. |
| * @returns The converted function. |
| */ |
| toPackedFunc(func: Function): PackedFunc { |
| return this.toPackedFuncInternal(func, true); |
| } |
| |
| private toPackedFuncInternal(func: Function, autoAttachToScope: boolean): PackedFunc { |
| if (this.isPackedFunc(func)) return func as PackedFunc; |
| const ret = this.createPackedFuncFromSafeCallType(this.wrapJSFuncAsSafeCallType(func)); |
| if (autoAttachToScope) return this.ctx.attachToCurrentScope(ret); |
| return ret; |
| } |
| |
| /** |
| * Setup a virtual machine module with given device. |
| * |
| * @param dev DLDevice the device. |
| * @returns The created virtual machime. |
| */ |
| createVirtualMachine(dev: DLDevice): VirtualMachine { |
| const mod = this.ctx.detachFromCurrentScope( |
| this.systemLib().getFunction("vm_load_executable")() |
| ); |
| return this.ctx.attachToCurrentScope( |
| new VirtualMachine(mod, dev) |
| ); |
| } |
| |
| //----------------------------------------------- |
| // Native Tensor Cache Support |
| //----------------------------------------------- |
| /** |
| * Register a call back for fetch progress. |
| * |
| * @param cb the fetch progress callback. |
| */ |
| registerInitProgressCallback(cb: InitProgressCallback) { |
| this.initProgressCallback.push(cb); |
| } |
| |
| /** |
| * Get parameters in the form of prefix_i |
| * |
| * @param prefix The parameter prefix. |
| * @param numParams Number of parameters. |
| * @returns |
| */ |
| getParamsFromCache(prefix: string, numParams: number): TVMObject { |
| return (this.ctx.paramModuleFromCache( |
| prefix, new Scalar(numParams, "int32")) as Module).getFunction("get_params")(); |
| } |
| |
| /** |
| * Get parameters based on parameter names provided |
| * |
| * @param paramNames Names of the parameters. |
| * @returns Parameters read. |
| */ |
| getParamsFromCacheByName(paramNames: Array<string>): TVMObject { |
| return (this.ctx.paramModuleFromCacheByName(paramNames) as Module).getFunction("get_params")(); |
| } |
| |
| /** |
| * Get Tensor from cache. |
| * @param name The name of array. |
| * @returns The result. |
| */ |
| tensorCacheGet(name: string): Tensor | undefined { |
| return this.ctx.tensorCacheGet(name); |
| } |
| |
| /** |
| * Get Tensor from cache. |
| * @param name The name of array. |
| * @returns The result. |
| */ |
| tensorCacheRemove(name: string): Tensor | undefined { |
| return this.ctx.tensorCacheRemove(name); |
| } |
| |
| /** |
| * Update the tensor cache. |
| * @param name The name of the array. |
| * @param arr The content. |
| */ |
| tensorCacheUpdate(name: string, arr: Tensor, override = false) { |
| this.ctx.tensorCacheUpdate(name, arr, this.scalar(override ? 1 : 0, "int32")); |
| } |
| |
| /** |
| * Update the tensor cache. |
| * @param name The name of the array. |
| * @param arr The content. |
| */ |
| tensorCacheClear() { |
| this.ctx.tensorCacheClear(); |
| } |
| |
| /** |
| * Given cacheUrl, search up items to fetch based on cacheUrl/tensor-cache.json |
| * |
| * @param tensorCacheUrl The cache url. |
| * @param device The device to be fetched to. |
| * @param cacheScope The scope identifier of the cache |
| * @param cacheType The type of the cache: "cache" or "indexedDB" |
| * @param signal An optional AbortSignal to abort the fetch |
| * @returns The meta data |
| */ |
| async fetchTensorCache( |
| tensorCacheUrl: string, |
| device: DLDevice, |
| cacheScope = "tvmjs", |
| cacheType = "cache", |
| signal?: AbortSignal, |
| ): Promise<any> { |
| let artifactCache: ArtifactCacheTemplate; |
| if (cacheType === undefined || cacheType.toLowerCase() === "cache") { |
| artifactCache = new ArtifactCache(cacheScope); |
| } else if (cacheType.toLowerCase() == "indexeddb") { |
| artifactCache = new ArtifactIndexedDBCache(cacheScope); |
| } else { |
| console.error("Unsupported cacheType: " + cacheType + ", using default ArtifactCache."); |
| artifactCache = new ArtifactCache(cacheScope); |
| } |
| const jsonUrl = new URL("tensor-cache.json", tensorCacheUrl).href; |
| const list = await artifactCache.fetchWithCache(jsonUrl, "json"); |
| await this.fetchTensorCacheInternal( |
| tensorCacheUrl, |
| list["records"] as Array<TensorShardEntry>, device, artifactCache, |
| signal); |
| this.cacheMetadata = { ...this.cacheMetadata, ...(list["metadata"] as Record<string, any>) }; |
| } |
| |
| |
| /** |
| * Fetch list of Tensor into the TensorCache. |
| * |
| * @param tensorCacheUrl The cache url. |
| * @param list The list of array data. |
| * @param device The device to store the data to. |
| * @param artifactCache The artifact cache |
| * @param signal An optional AbortSignal to abort the fetch |
| */ |
| private async fetchTensorCacheInternal( |
| tensorCacheUrl: string, |
| list: Array<TensorShardEntry>, |
| device: DLDevice, |
| artifactCache: ArtifactCacheTemplate, |
| signal?: AbortSignal, |
| ) { |
| const perf = compact.getPerformance(); |
| const tstart = perf.now(); |
| let totalBytes = 0; |
| for (let i = 0; i < list.length; ++i) { |
| totalBytes += list[i].nbytes; |
| } |
| let fetchedBytes = 0; |
| let fetchedShards = 0; |
| let timeElapsed = 0; |
| |
| const cacheOnly = await artifactCache.hasAllKeys(list.map(key => new URL(key.dataPath, tensorCacheUrl).href)); |
| |
| // `loading`: we have finished downloading (or already cacheOnly) and are loading onto WebGPU |
| const reportCallback = (iter: number, loading = false) => { |
| // report |
| for (let j = 0; j < this.initProgressCallback.length; ++j) { |
| let text: string; |
| if (loading) { |
| text = "Loading model from cache[" + iter + "/" + list.length + "]: "; |
| text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB loaded. " |
| text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, " |
| text += timeElapsed + " secs elapsed."; |
| } else { |
| text = "Fetching param cache[" + iter + "/" + list.length + "]: "; |
| text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB fetched. " |
| text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, " |
| text += timeElapsed + " secs elapsed."; |
| text += " It can take a while when we first visit this page to populate the cache." |
| text += " Later refreshes will become faster."; |
| } |
| this.initProgressCallback[j]({ |
| progress: fetchedBytes / totalBytes, |
| timeElapsed: timeElapsed, |
| text: text |
| }); |
| } |
| }; |
| |
| for (let j = 0; j < this.initProgressCallback.length; ++j) { |
| this.initProgressCallback[j]({ |
| progress: fetchedBytes / totalBytes, |
| timeElapsed: 0, |
| text: "Start to fetch params", |
| }); |
| } |
| |
| // First download all shards to cache parallely if not yet in cache |
| const downloadCache = async (start: number, end: number) => { |
| // Download params [start, end) from `list` |
| for (let i = start; i < end; i++) { |
| const shard = list[i]; |
| const dataUrl = new URL(shard.dataPath, tensorCacheUrl).href; |
| try { |
| await artifactCache.addToCache(dataUrl, "arraybuffer", signal); |
| } catch (err) { |
| this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err); |
| throw err; |
| } |
| timeElapsed = Math.ceil((perf.now() - tstart) / 1000); |
| fetchedBytes += shard.nbytes; |
| reportCallback(fetchedShards++, /*loading=*/false); |
| } |
| } |
| // We launch 4 parallel for loops to limit the max concurrency to 4 download |
| if (!cacheOnly) { |
| const loopSize = Math.floor(list.length / 4); |
| await Promise.all([ |
| downloadCache(0, loopSize), |
| downloadCache(loopSize, 2 * loopSize), |
| downloadCache(2 * loopSize, 3 * loopSize), |
| downloadCache(3 * loopSize, list.length) |
| ]); |
| } |
| |
| // Then iteratively, load the shard from cache |
| for (let i = 0; i < list.length; ++i) { |
| const shard = list[i]; |
| const dataUrl = new URL(shard.dataPath, tensorCacheUrl).href; |
| let buffer; |
| try { |
| buffer = await artifactCache.fetchWithCache(dataUrl, "arraybuffer"); |
| } catch (err) { |
| this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err); |
| throw err; |
| } |
| const shardRecords = shard.records; |
| for (let j = 0; j < shardRecords.length; ++j) { |
| try { |
| const rec = shardRecords[j]; |
| const cpu_arr = this.withNewScope(() => { |
| return this.detachFromCurrentScope( |
| this.empty(rec.shape, rec.dtype, this.cpu()) |
| ) |
| }); |
| const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes); |
| // first sync copy to cpu. |
| this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype); |
| // then async stream into GPU if needed |
| if (device.deviceType === DeviceStrToEnum.cpu) { |
| this.tensorCacheUpdate(rec.name, cpu_arr, false); |
| cpu_arr.dispose(); |
| } else { |
| // allocate a gpu arr and async copy to it. |
| const gpu_arr = this.withNewScope(() => { |
| return this.detachFromCurrentScope( |
| this.empty(rec.shape, rec.dtype, device) |
| ) |
| }); |
| gpu_arr.copyFrom(cpu_arr); |
| await device.sync(); |
| this.tensorCacheUpdate(rec.name, gpu_arr, false); |
| cpu_arr.dispose(); |
| gpu_arr.dispose(); |
| } |
| } catch (err) { |
| this.env.logger( |
| "Failed to load shard " + i + "'s record: " + JSON.stringify(shardRecords[j]) + "\n" + |
| "Error: " + err |
| ); |
| throw err; |
| } |
| } |
| reportCallback(i + 1, /*loading=*/true); |
| } |
| } |
| |
| /** |
| * Create a new {@link Scalar} that can be passed to a PackedFunc. |
| * @param value The number value. |
| * @param dtype The dtype string. |
| * @returns The created scalar. |
| */ |
| scalar(value: number, dtype: string): Scalar { |
| return new Scalar(value, dtype); |
| } |
| |
| /** |
| * Create a new {@link DLDevice} |
| * @param deviceType The device type. |
| * @param deviceId The device index. |
| * @returns The created device. |
| */ |
| device(deviceType: number | string, deviceId = 0): DLDevice { |
| return new DLDevice(deviceType, deviceId, this.lib); |
| } |
| |
| /** |
| * Create a new cpu {@link DLDevice} |
| * @param deviceId The device index. |
| */ |
| cpu(deviceId = 0): DLDevice { |
| return this.device("cpu", deviceId); |
| } |
| |
| /** |
| * Create a new webgpu {@link DLDevice} |
| * @param deviceId The device index. |
| */ |
| webgpu(deviceId = 0): DLDevice { |
| return this.device("webgpu", deviceId); |
| } |
| |
| /** |
| * Create an empty {@link Tensor} with given shape and dtype. |
| * |
| * @param shape The shape of the array. |
| * @param dtype The data type of the array. |
| * @param dev The device of the ndarray. |
| * @returns The created ndarray. |
| */ |
| empty( |
| shape: Array<number> | number, |
| dtype: string | DLDataType = "float32", |
| dev: DLDevice = this.device("cpu", 0) |
| ): Tensor { |
| shape = typeof shape === "number" ? [shape] : shape; |
| return this.ctx.tensorEmpty(this.makeShapeTuple(shape), dtype, dev, null); |
| } |
| |
| /** |
| * Create am uniform {@link Tensor} with given shape. |
| * |
| * @param shape The shape of the array. |
| * @param low The low value. |
| * @param high The high value. |
| * @param dev The device of the ndarray. |
| * @returns The created ndarray. |
| */ |
| uniform( |
| shape: Array<number>, |
| low: number, |
| high: number, |
| dev: DLDevice |
| ): Tensor { |
| const ret = this.empty(shape, "float32", dev); |
| const size = shape.reduce((a, b) => { |
| return a * b; |
| }, 1); |
| const scale = high - low; |
| const input = new Float32Array(size); |
| for (let i = 0; i < input.length; ++i) { |
| input[i] = low + this.rng.randomFloat() * scale; |
| } |
| return ret.copyFrom(input); |
| } |
| |
| /** |
| * Set the seed of the internal LinearCongruentialGenerator. |
| */ |
| setSeed(seed: number): void { |
| this.rng.setSeed(seed); |
| } |
| |
| /** |
| * Sample index via top-p sampling. |
| * |
| * @param logits The input logits before normalization. |
| * @param temperature The temperature factor, will take argmax if temperature = 0.0 |
| * @param top_p The top_p |
| * @returns The sampled index. |
| */ |
| sampleTopPFromLogits(logits: Tensor, temperature: number, top_p: number): number { |
| return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, this.rng.randomFloat()); |
| } |
| |
| /** |
| * Sample index via top-p sampling. |
| * |
| * @param prob The distribution, i.e. logits after `applySoftmaxWithTemperature()` is performed. |
| * @param top_p The top_p |
| * @returns The sampled index. |
| */ |
| sampleTopPFromProb(prob: Tensor, top_p: number): number { |
| return this.ctx.sampleTopPFromProb(prob, top_p, this.rng.randomFloat()); |
| } |
| |
| /** |
| * Apply repetition penalty to the logits. |
| * @param logits The input logits before penalty. |
| * @param token_ids The appeared token ids. |
| * @param penalty The penalty factor. |
| */ |
| applyRepetitionPenalty(logits: Tensor, token_ids: Tensor, penalty: number) { |
| return this.ctx.applyRepetitionPenalty(logits, token_ids, penalty); |
| } |
| |
| /** |
| * Apply presence and frequency penalty. This is an inplace operation. |
| * @param logits The input logits before penalty. |
| * @param token_ids The appeared token ids. |
| * @param token_freqs The number of times each token has appeared since last PrefillStep. |
| * token_freqs[i] is the frequency of token_ids[i], for all i. And all token_freqs should be >= 1. |
| * @param presence_penalty The penalty factor. |
| * @param frequency_penalty The penalty factor. |
| */ |
| applyPresenceAndFrequencyPenalty( |
| logits: Tensor, |
| token_ids: Tensor, |
| token_freqs: Tensor, |
| presence_penalty: number, |
| frequency_penalty: number |
| ) { |
| return this.ctx.applyPresenceAndFrequencyPenalty( |
| logits, token_ids, token_freqs, presence_penalty, frequency_penalty |
| ); |
| } |
| |
| /** |
| * Apply softmax with temperature to the logits. |
| * @param logits The input logits before softmax w/ temperature. |
| * @param temperature The temperature factor. |
| */ |
| applySoftmaxWithTemperature(logits: Tensor, temperature: number) { |
| return this.ctx.applySoftmaxWithTemperature(logits, temperature); |
| } |
| |
| /** |
| * Bind canvas to the current WebGPU context |
| * @param canvas The canvas. |
| */ |
| bindCanvas(canvas: HTMLCanvasElement) { |
| this.lib.webGPUContext?.bindCanvas(canvas); |
| } |
| |
| /** |
| * Show image in canvas. |
| * |
| * @param dataRGBA Image array in height x width uint32 Tensor RGBA format on GPU. |
| */ |
| showImage(dataRGBA: Tensor) { |
| if (dataRGBA.shape.length != 2) { |
| throw Error("Require a height x width uint32 Tensor in RGBA" + |
| "get shape=" + dataRGBA.shape.toString() + " instead." |
| ); |
| } |
| if (dataRGBA.device.deviceType != DeviceStrToEnum.webgpu) { |
| throw new Error("Can only run showImage on WebGPU array, " + |
| "get " + DeviceEnumToStr[dataRGBA.device.deviceType] + " instead."); |
| } |
| if (dataRGBA.dtype != "uint32") { |
| throw Error("Require a height x width uint32 Tensor in RGBA, " + |
| "get " + dataRGBA.dtype + " instead."); |
| } |
| this.lib.webGPUContext?.drawImageFromBuffer( |
| dataRGBA.getDataPtr(), dataRGBA.shape[0], dataRGBA.shape[1] |
| ); |
| } |
| |
| /** |
| * Clear canvas |
| */ |
| clearCanvas() { |
| this.lib.webGPUContext?.clearCanvas(); |
| } |
| |
| /** |
| * Create an tuple {@link TVMArray} input array. |
| * |
| * The input array can be passed to tvm runtime function |
| * and needs to b explicitly disposed. |
| * |
| * @param inputs The input array |
| * @returns The result array. |
| */ |
| makeTVMArray( |
| inputs: Array<any> |
| ): TVMArray { |
| const CALL_STACK_LIMIT = 30000; |
| const inputsLength = inputs.length; |
| if (inputsLength <= CALL_STACK_LIMIT) { |
| return this.ctx.arrayMake(...inputs) as TVMArray; |
| } |
| // If too many elements, TypeScript would complain `Maximum call stack size exceeded` |
| // So we make several arrays and concatenate them |
| const listOfArrays: Array<TVMArray> = []; |
| for (let begin = 0; begin < inputsLength; begin += CALL_STACK_LIMIT) { |
| const end = Math.min(inputsLength, begin + CALL_STACK_LIMIT); |
| const chunk: Array<any> = inputs.slice(begin, end); |
| listOfArrays.push(this.ctx.arrayMake(...chunk) as TVMArray); |
| } |
| return this.ctx.arrayConcat(...listOfArrays) as TVMArray; |
| } |
| |
| /** |
| * Join a sequence of Tensors that represent embeddings. |
| * @param inputs A list of embeddings in Tensors, each array i has shape (m_i, hidden_size). |
| * @returns An Tensor of shape (\sum_{i} {m}, hidden_size) |
| */ |
| concatEmbeddings(embeddings: Array<Tensor>): Tensor { |
| // 1. Check shape validity |
| const hidden_size = embeddings[0].shape[1]; |
| embeddings.forEach((input) => { |
| if (input.shape.length !== 2 || input.shape[1] !== hidden_size) { |
| throw new Error("Expect embeddings to concatenate have shape (m_i, hidden_size)."); |
| } |
| }) |
| |
| // 2. Call global func |
| if (this.ctx.concatEmbeddings === undefined) { |
| throw new Error( |
| "Global function tvmjs.runtime.ConcatEmbeddings was " + |
| "not found, but called concatEmbeddings." |
| ); |
| } |
| return this.ctx.concatEmbeddings(...embeddings) as Tensor; |
| } |
| |
| /** |
| * Create a shape tuple to pass to runtime. |
| * @param shape The shape . |
| * @returns The created shape tuple. |
| */ |
| makeShapeTuple(shape: Array<number>): TVMObject { |
| const shapeArray = shape.map((value) => new Scalar(value, "int")); |
| return this.ctx.makeShapeTuple(...shapeArray); |
| } |
| /** |
| * Get type index from type key. |
| * @param typeKey The type key. |
| * @returns The corresponding type index. |
| */ |
| typeKey2Index( |
| typeKey: string |
| ): number { |
| const stack = this.lib.getOrAllocCallStack(); |
| const typeKeyOffset = stack.allocByteArrayForString(typeKey); |
| const outOffset = stack.allocPtrArray(1); |
| const outPtr = stack.ptrFromOffset(outOffset); |
| |
| stack.commitToWasmMemory(outOffset); |
| this.lib.checkCall( |
| (this.lib.exports.TVMFFITypeKeyToIndex as ctypes.FTVMFFITypeKeyToIndex)( |
| stack.ptrFromOffset(typeKeyOffset), |
| outPtr |
| ) |
| ); |
| const typeIndex = this.memory.loadU32(outPtr); |
| this.lib.recycleCallStack(stack); |
| return typeIndex; |
| } |
| |
| /** |
| * Register an object constructor. |
| * @param typeKey The name of the function. |
| * @param func Function to be registered. |
| * @param override Whether overwrite function in existing registry. |
| */ |
| registerObjectConstructor( |
| typeKey: string, |
| func: FObjectConstructor, |
| override = false |
| ): void { |
| const typeIndex = this.typeKey2Index(typeKey); |
| if (this.objFactory.has(typeIndex)) { |
| if (!override) { |
| throw new Error("Type " + typeKey + " already registered"); |
| } |
| } |
| this.objFactory.set(typeIndex, func); |
| } |
| |
| /** |
| * Wrap a function obtained from tvm runtime as AsyncPackedFunc |
| * through the asyncify mechanism |
| * |
| * You only need to call it if the function may contain callback into async |
| * JS function via asynctify. A common one can be GPU synchronize. |
| * |
| * It is always safe to wrap any function as Asynctify, however you do need |
| * to make sure you use await when calling the funciton. |
| * |
| * @param func The PackedFunc. |
| * @returns The wrapped AsyncPackedFunc |
| */ |
| wrapAsyncifyPackedFunc(func: PackedFunc): AsyncPackedFunc { |
| const asyncFunc = this.asyncifyHandler.wrapExport(func) as AsyncPackedFunc; |
| asyncFunc.dispose = func.dispose; |
| asyncFunc._tvmPackedCell = func._tvmPackedCell; |
| return asyncFunc; |
| } |
| |
| /** |
| * Register async function as asynctify callable in global environment. |
| * |
| * @param name The name of the function. |
| * @param func function to be registered. |
| * @param override Whether overwrite function in existing registry. |
| * |
| * @note This function is handled via asynctify mechanism |
| * The wasm needs to be compiled with Asynctify |
| */ |
| registerAsyncifyFunc( |
| name: string, |
| func: (...args: Array<any>) => Promise<any>, |
| override = false |
| ): void { |
| const asyncWrapped = this.asyncifyHandler.wrapImport(func); |
| this.registerFunc(name, asyncWrapped, override); |
| } |
| |
| /** |
| * Register an asyncfunction to be global function in the server. |
| * |
| * @param name The name of the function. |
| * @param func function to be registered. |
| * @param override Whether overwrite function in existing registry. |
| * |
| * @note The async function will only be used for serving remote calls in the rpc |
| * These functions contains explicit continuation |
| */ |
| registerAsyncServerFunc( |
| name: string, |
| func: Function, |
| override = false |
| ): void { |
| const asyncVariant = (...args: Array<any>): void => { |
| const fargs = args.slice(0, args.length - 1); |
| // need to keep it alive until callback is fulfilled. |
| const callback = this.detachFromCurrentScope(args[args.length - 1] as PackedFunc); |
| const promise: Promise<any> = func(...fargs); |
| const onFulfilled = (rv: any) => { |
| callback(this.scalar(AsyncCallbackCode.kReturn, "int32"), rv); |
| callback.dispose(); |
| }; |
| const onRejected = (reason: any) => { |
| callback(this.scalar(AsyncCallbackCode.kException, "int32"), reason.toString()); |
| callback.dispose(); |
| }; |
| promise.then(onFulfilled, onRejected); |
| }; |
| this.registerFunc("__async." + name, asyncVariant, override); |
| } |
| |
| /** |
| * Asynchronously load webgpu pipelines when possible. |
| * @param mod The input module. |
| */ |
| async asyncLoadWebGPUPipelines(mod: Module): Promise<void> { |
| if (this.lib.webGPUContext === undefined) throw Error("WebGPU not initialied"); |
| const webgpuContext = this.lib.webGPUContext; |
| |
| this.beginScope(); |
| const fmap_str = mod.getFunction("webgpu.get_fmap", true)() as string; |
| const fmap: Record<string, FunctionInfo> = JSON.parse(fmap_str); |
| const fGetShader = this.detachFromCurrentScope( |
| mod.getFunction("webgpu.get_shader") |
| ); |
| const fUpdatePrebuild = this.detachFromCurrentScope( |
| mod.getFunction("webgpu.update_prebuild") |
| ); |
| this.endScope(); |
| |
| const perf = compact.getPerformance(); |
| const tstart = perf.now(); |
| let tlastReport = tstart; |
| let finishCounter = 0; |
| const fmapEntries = Object.entries(fmap); |
| |
| let allEvents = Promise.resolve(); |
| |
| for (const [key, finfo] of fmapEntries) { |
| const code = fGetShader(key); |
| assert(key === finfo.name); |
| const event = webgpuContext.createShaderAsync(finfo, code).then((func) => { |
| this.beginScope(); |
| fUpdatePrebuild(key, func); |
| this.endScope(); |
| |
| }).then(() => { |
| finishCounter += 1; |
| const tend = perf.now(); |
| // skip report if gap is smaller than 1000 |
| if ((tend - tlastReport) < 1000 && finishCounter != fmapEntries.length) { |
| return; |
| } |
| tlastReport = tend; |
| const timeElapsed = Math.ceil((perf.now() - tstart) / 1000); |
| // report |
| for (let j = 0; j < this.initProgressCallback.length; ++j) { |
| const progress = finishCounter / fmapEntries.length; |
| let text = "Loading GPU shader modules[" + finishCounter + "/" + fmapEntries.length + "]: "; |
| text += Math.floor(progress * 100).toString() + "% completed, " |
| text += timeElapsed + " secs elapsed."; |
| this.initProgressCallback[j]({ |
| progress: progress, |
| timeElapsed: timeElapsed, |
| text: text |
| }); |
| } |
| }); |
| allEvents = Promise.all([allEvents, event]).then(() => { }); |
| } |
| await allEvents; |
| assert(finishCounter === fmapEntries.length); |
| } |
| |
| /** |
| * Initialize webgpu in the runtime. |
| * @param device The given GPU device. |
| */ |
| initWebGPU(device: GPUDevice): void { |
| device.addEventListener("uncapturederror", (event) => { |
| console.error("A WebGPU error was not captured: ", event); |
| }); |
| |
| device.lost.then((info: any) => { |
| if (this.deviceLostIsError) { |
| console.error("Device lost, calling Instance.dispose(). Please initialize again. ", info); |
| this.dispose(); |
| } |
| }); |
| this.deviceLostIsError = true; |
| |
| const webGPUContext = new WebGPUContext( |
| this.memory, device |
| ); |
| this.registerFunc("wasm.WebGPUDeviceAPI", (name: string) => { |
| return webGPUContext.getDeviceAPI(name); |
| }); |
| this.registerFunc("wasm.WebGPUCreateShader", (info: string, code: string) => { |
| const finfo = JSON.parse(info) as FunctionInfo; |
| return webGPUContext.createShader(finfo, code); |
| }); |
| this.registerAsyncServerFunc("wasm.WebGPUWaitForTasks", async () => { |
| await webGPUContext.sync(); |
| }); |
| if (this.asyncifyHandler.enabled()) { |
| this.registerAsyncifyFunc("__asyncify.WebGPUWaitForTasks", async () => { |
| await webGPUContext.sync(); |
| }); |
| } |
| this.lib.webGPUContext = webGPUContext; |
| } |
| |
| /** Register all object factory */ |
| private registerObjectFactoryFuncs(): void { |
| this.registerObjectConstructor("ffi.Array", |
| (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { |
| return new TVMArray(handle, lib, ctx); |
| }); |
| this.registerObjectConstructor("ffi.Module", |
| (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { |
| return new Module(handle, lib, ctx); |
| }); |
| } |
| |
| /** Register global packed functions needed by the backend to the env. */ |
| private registerEnvGlobalPackedFuncs(): void { |
| // Register the timer function to enable the time_evaluator. |
| const perf = compact.getPerformance(); |
| |
| // Helper function to time the finvoke |
| const timeExecution = async ( |
| finvoke: PackedFunc, |
| dev: DLDevice, |
| nstep: number, |
| repeat: number, |
| minRepeatMs: number, |
| limitZeroTimeIterations: number, |
| cooldownIntervalMs: number, |
| repeatsToCooldown: number |
| ): Promise<Uint8Array> => { |
| // detach and explicit dispose when tasks is fullfilled |
| // the promise will immediately return and we need to makesure |
| // finvoke do not get recycled. |
| this.ctx.detachFromCurrentScope(finvoke); |
| |
| finvoke(this.scalar(1, "int32")); |
| await dev.sync(); |
| const result = []; |
| let setupNumber: number = nstep; |
| |
| for (let i = 0; i < repeat; ++i) { |
| let durationMs = 0.0; |
| let absoluteZeroTimes = 0; |
| do { |
| if (durationMs > 0.0) { |
| const golden_ratio = 1.618; |
| setupNumber = Math.floor( |
| Math.max(minRepeatMs / (durationMs / setupNumber) + 1, setupNumber * golden_ratio) |
| ); |
| } |
| const tstart: number = perf.now(); |
| finvoke(this.scalar(setupNumber, "int32")); |
| await dev.sync(); |
| const tend: number = perf.now(); |
| |
| durationMs = tend - tstart; |
| if (durationMs === 0) { |
| absoluteZeroTimes++; |
| } |
| } while (durationMs < minRepeatMs && absoluteZeroTimes < limitZeroTimeIterations); |
| const speed = durationMs / setupNumber / 1000; |
| result.push(speed); |
| if (cooldownIntervalMs > 0.0 && (i % repeatsToCooldown) === 0) { |
| await new Promise(r => setTimeout(r, cooldownIntervalMs)); |
| } |
| } |
| const ret = new Float64Array(result.length); |
| ret.set(result); |
| |
| // dispose finvoke |
| finvoke.dispose(); |
| return new Uint8Array(ret.buffer); |
| }; |
| |
| const addOne = async (x: number): Promise<number> => { |
| await new Promise(resolve => setTimeout(resolve, 100)); |
| return x + 1; |
| }; |
| |
| this.registerAsyncServerFunc("wasm.TimeExecution", timeExecution); |
| this.registerAsyncServerFunc("testing.asyncAddOne", addOne); |
| } |
| |
| private createPackedFuncFromSafeCallType( |
| func: ctypes.FTVMFFIWasmSafeCallType |
| ): PackedFunc { |
| let findex = this.env.packedCFuncTable.length; |
| if (this.env.packedCFuncTableFreeId.length != 0) { |
| findex = this.env.packedCFuncTableFreeId.pop() as number; |
| } else { |
| this.env.packedCFuncTable.push(undefined); |
| } |
| this.env.packedCFuncTable[findex] = func; |
| |
| const stack = this.lib.getOrAllocCallStack(); |
| const outOffset = stack.allocPtrArray(1); |
| const outPtr = stack.ptrFromOffset(outOffset); |
| this.lib.checkCall( |
| (this.exports |
| .TVMFFIWasmFunctionCreate as ctypes.FTVMFFIWasmFunctionCreate)( |
| findex, |
| outPtr |
| ) |
| ); |
| const ret = this.makePackedFunc(this.memory.loadPointer(outPtr)); |
| this.lib.recycleCallStack(stack); |
| return ret; |
| } |
| |
| /** |
| * Set packed function arguments into the location indicated by argsValue and argsCode. |
| * Allocate new temporary space from the stack if necessary. |
| * |
| * @parma stack The call stack |
| * @param args The input arguments. |
| * @param packedArgs The offset of packedArgs. |
| */ |
| setPackedArguments( |
| stack: CachedCallStack, |
| args: Array<any>, |
| packedArgs: PtrOffset, |
| ): void { |
| for (let i = 0; i < args.length; ++i) { |
| let val = args[i]; |
| const tp = typeof val; |
| const argOffset = packedArgs + i * SizeOf.TVMFFIAny; |
| const argTypeIndexOffset = argOffset; |
| const argZeroPaddingOffset = argOffset + SizeOf.I32; |
| const argValueOffset = argOffset + SizeOf.I32 * 2; |
| |
| // Convert string[] to a TVMArray of, hence treated as a TVMObject |
| if (val instanceof Array && val.every(e => typeof e === "string")) { |
| const tvmStringArray: string[] = []; |
| val.forEach(e => { tvmStringArray.push(e) }); |
| val = this.makeTVMArray(tvmStringArray); |
| } |
| |
| // clear off the extra zero padding before ptr storage |
| stack.storeI32(argZeroPaddingOffset, 0); |
| // clear off the extra zero padding after ptr storage |
| stack.storeI32(argValueOffset + SizeOf.I32, 0); |
| if (val instanceof Tensor) { |
| if (!val.isView) { |
| stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFITensor); |
| stack.storePtr(argValueOffset, val.getHandle()); |
| } else { |
| stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIDLTensorPtr); |
| stack.storePtr(argValueOffset, val.getHandle()); |
| } |
| } else if (val instanceof Scalar) { |
| if (val.dtype.startsWith("int") || val.dtype.startsWith("uint")) { |
| stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIInt); |
| stack.storeI64(argValueOffset, val.value); |
| } else if (val.dtype.startsWith("float")) { |
| stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFloat); |
| stack.storeF64(argValueOffset, val.value); |
| } else { |
| assert(val.dtype === "handle", "Expect handle"); |
| stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIOpaquePtr); |
| stack.storePtr(argValueOffset, val.value); |
| } |
| } else if (val instanceof DLDevice) { |
| stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIDevice); |
| stack.storeI32(argValueOffset, val.deviceType); |
| stack.storeI32(argValueOffset + SizeOf.I32, val.deviceId); |
| } else if (tp === "boolean") { |
| stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIBool); |
| stack.storeI64(argValueOffset, val ? 1 : 0); |
| } else if (tp === "number") { |
| stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFloat); |
| stack.storeF64(argValueOffset, val); |
| // eslint-disable-next-line no-prototype-builtins |
| } else if (tp === "function" && val.hasOwnProperty("_tvmPackedCell")) { |
| stack.storePtr(argValueOffset, val._tvmPackedCell.getHandle()); |
| stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFunction); |
| } else if (val === null || val === undefined) { |
| stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFINone); |
| stack.storePtr(argValueOffset, 0); |
| } else if (tp === "string") { |
| stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIRawStr); |
| stack.allocThenSetArgString(argValueOffset, val); |
| } else if (val instanceof Uint8Array) { |
| stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIByteArrayPtr); |
| stack.allocThenSetArgBytes(argValueOffset, val); |
| } else if (val instanceof Function) { |
| val = this.toPackedFuncInternal(val, false); |
| stack.tempArgs.push(val); |
| stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIFunction); |
| stack.storePtr(argValueOffset, val._tvmPackedCell.getHandle()); |
| } else if (val instanceof Module) { |
| stack.storeI32(argTypeIndexOffset, TypeIndex.kTVMFFIModule); |
| stack.storePtr(argValueOffset, val.getHandle()); |
| } else if (val instanceof TVMObject) { |
| stack.storeI32(argTypeIndexOffset, val.typeIndex()); |
| stack.storePtr(argValueOffset, val.getHandle()); |
| } else { |
| throw new Error("Unsupported argument type " + tp + " value=`" + val.toString() + "`"); |
| } |
| } |
| } |
| |
| private wrapJSFuncAsSafeCallType(func: Function): ctypes.FTVMFFIWasmSafeCallType { |
| const lib = this.lib; |
| return ( |
| // eslint-disable-next-line @typescript-eslint/no-unused-vars |
| self: Pointer, |
| packedArgs: Pointer, |
| numArgs: number, |
| ret: Pointer |
| ): number => { |
| const jsArgs = []; |
| // use scope to track js values. |
| this.ctx.beginScope(); |
| for (let i = 0; i < numArgs; ++i) { |
| const argPtr = packedArgs + i * SizeOf.TVMFFIAny; |
| const typeIndex = lib.memory.loadI32(argPtr); |
| |
| if (typeIndex >= TypeIndex.kTVMFFIRawStr) { |
| // NOTE: the following code have limitations in asyncify mode. |
| // The reason is that the TVMFFIAnyViewToOwnedAny will simply |
| // get skipped during the rewinding process, causing memory failure |
| if (!this.asyncifyHandler.isNormalStackState()) { |
| throw Error("Cannot handle str/object argument callback in asyncify mode"); |
| } |
| lib.checkCall( |
| (lib.exports.TVMFFIAnyViewToOwnedAny as ctypes.FTVMFFIAnyViewToOwnedAny)( |
| argPtr, |
| argPtr |
| ) |
| ); |
| } |
| jsArgs.push(this.retValueToJS(argPtr, true)); |
| } |
| |
| let rv: any; |
| try { |
| rv = func(...jsArgs); |
| } catch (error) { |
| // error handling |
| // store error via SetLastError |
| this.ctx.endScope(); |
| const errKind = "JSCallbackError" |
| const errMsg = error.message; |
| const stack = lib.getOrAllocCallStack(); |
| const errKindOffset = stack.allocRawBytes(errKind.length + 1); |
| stack.storeRawBytes(errKindOffset, StringToUint8Array(errKind)); |
| const errMsgOffset = stack.allocRawBytes(errMsg.length + 1); |
| stack.storeRawBytes(errMsgOffset, StringToUint8Array(errMsg)); |
| stack.commitToWasmMemory(); |
| (this.lib.exports.FTVMFFIErrorSetRaisedFromCStr as ctypes.FTVMFFIErrorSetRaisedFromCStr)( |
| stack.ptrFromOffset(errKindOffset), |
| stack.ptrFromOffset(errMsgOffset) |
| ); |
| this.lib.recycleCallStack(stack); |
| return -1; |
| } |
| |
| // normal return path |
| // recycle all js object value in function unless we want to retain them. |
| this.ctx.endScope(); |
| if (rv !== undefined && rv !== null) { |
| const stack = lib.getOrAllocCallStack(); |
| const argOffset = stack.allocRawBytes(SizeOf.TVMFFIAny); |
| this.setPackedArguments(stack, [rv], argOffset); |
| stack.commitToWasmMemory(); |
| const argPtr = stack.ptrFromOffset(argOffset); |
| lib.checkCall( |
| (lib.exports.TVMFFIAnyViewToOwnedAny as ctypes.FTVMFFIAnyViewToOwnedAny)( |
| argPtr, |
| ret |
| ) |
| ); |
| lib.recycleCallStack(stack); |
| } |
| return 0; |
| }; |
| } |
| |
| private makePackedFunc(handle: Pointer): PackedFunc { |
| const cell = new PackedFuncCell(handle, this.lib, this.ctx); |
| const packedFunc = (...args: any): any => { |
| const stack = this.lib.getOrAllocCallStack(); |
| const argsOffset = stack.allocRawBytes(SizeOf.TVMFFIAny * args.length); |
| this.setPackedArguments(stack, args, argsOffset); |
| const retOffset = stack.allocRawBytes(SizeOf.TVMFFIAny); |
| // pre-store the result to be null |
| stack.storeI32(retOffset, TypeIndex.kTVMFFINone); |
| // clear off the extra zero padding before ptr storage |
| stack.storeI32(retOffset + SizeOf.I32, 0); |
| stack.commitToWasmMemory(); |
| this.lib.checkCall( |
| (this.exports.TVMFFIFunctionCall as ctypes.FTVMFFIFunctionCall)( |
| cell.getHandle(), |
| stack.ptrFromOffset(argsOffset), |
| args.length, |
| stack.ptrFromOffset(retOffset) |
| ) |
| ); |
| |
| const ret = this.retValueToJS(stack.ptrFromOffset(retOffset), false); |
| this.lib.recycleCallStack(stack); |
| return ret; |
| }; |
| // Attach attributes to the function type. |
| // This is because javascript do not allow us to overload call. |
| const ret: any = packedFunc; |
| ret.dispose = (): void => { |
| cell.dispose(); |
| }; |
| ret._tvmPackedCell = cell; |
| return ret as PackedFunc; |
| } |
| |
| /** |
| * Creaye return value of the packed func. The value us auto-tracked for dispose. |
| * @param resultAnyPtr The location of rvalue |
| * @param callbackArg Whether it is being used in callbackArg. |
| * @returns The JS value. |
| */ |
| private retValueToJS(resultAnyPtr: Pointer, callbackArg: boolean): any { |
| const typeIndex = this.memory.loadI32(resultAnyPtr); |
| const valuePtr = resultAnyPtr + SizeOf.I32 * 2; |
| switch (typeIndex) { |
| case TypeIndex.kTVMFFINone: return undefined; |
| case TypeIndex.kTVMFFIBool: |
| return this.memory.loadI64(valuePtr) != 0; |
| case TypeIndex.kTVMFFIInt: |
| return this.memory.loadI64(valuePtr); |
| case TypeIndex.kTVMFFIFloat: |
| return this.memory.loadF64(valuePtr); |
| case TypeIndex.kTVMFFIOpaquePtr: { |
| return this.memory.loadPointer(valuePtr); |
| } |
| case TypeIndex.kTVMFFITensor: { |
| return this.ctx.attachToCurrentScope( |
| new Tensor(this.memory.loadPointer(valuePtr), this.lib, this.ctx, false) |
| ); |
| } |
| case TypeIndex.kTVMFFIDLTensorPtr: { |
| assert(callbackArg); |
| // no need to attach as we are only looking at view |
| return new Tensor(this.memory.loadPointer(valuePtr), this.lib, this.ctx, true); |
| } |
| case TypeIndex.kTVMFFIFunction: { |
| return this.ctx.attachToCurrentScope( |
| this.makePackedFunc(this.memory.loadPointer(valuePtr)) |
| ); |
| } |
| case TypeIndex.kTVMFFIDevice: { |
| const deviceType = this.memory.loadI32(valuePtr); |
| const deviceId = this.memory.loadI32(valuePtr + SizeOf.I32); |
| return this.device(deviceType, deviceId); |
| } |
| case TypeIndex.kTVMFFIDataType: { |
| // simply return dtype as tring to keep things simple |
| this.lib.checkCall( |
| (this.lib.exports.TVMFFIDataTypeToString as ctypes.FTVMFFIDataTypeToString)(valuePtr, valuePtr) |
| ); |
| const strObjPtr = this.memory.loadPointer(valuePtr); |
| const result = this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader); |
| this.lib.checkCall( |
| (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(strObjPtr) |
| ); |
| return result; |
| } |
| case TypeIndex.kTVMFFISmallStr: { |
| return this.memory.loadSmallStr(resultAnyPtr); |
| } |
| case TypeIndex.kTVMFFIStr: { |
| const strObjPtr = this.memory.loadPointer(valuePtr); |
| const result = this.memory.loadByteArrayAsString(strObjPtr + SizeOf.ObjectHeader); |
| this.lib.checkCall( |
| (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(strObjPtr) |
| ); |
| return result; |
| } |
| case TypeIndex.kTVMFFISmallBytes: { |
| return this.memory.loadSmallBytes(resultAnyPtr); |
| } |
| case TypeIndex.kTVMFFIBytes: { |
| const bytesObjPtr = this.memory.loadPointer(valuePtr); |
| const result = this.memory.loadByteArrayAsBytes(bytesObjPtr + SizeOf.ObjectHeader); |
| this.lib.checkCall( |
| (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(bytesObjPtr) |
| ); |
| return result; |
| } |
| default: { |
| if (typeIndex >= TypeIndex.kTVMFFIStaticObjectBegin) { |
| const obj = new TVMObject( |
| this.memory.loadPointer(valuePtr), |
| this.lib, |
| this.ctx |
| ); |
| const func = this.objFactory.get(obj.typeIndex()) |
| if (func != undefined) { |
| return this.ctx.attachToCurrentScope( |
| func(obj.getHandle(), this.lib, this.ctx) |
| ); |
| } else { |
| return this.ctx.attachToCurrentScope(obj); |
| } |
| } else { |
| throw new Error("Unsupported return type code=" + typeIndex); |
| } |
| } |
| } |
| } |
| } |
| |
| /** |
| * Asynchrously instantiate a new {@link Instance}. |
| * |
| * importObject can also be a {@link LibraryProvider} object, |
| * a WASI object, or an object containing wasmLibraryProvider field. |
| * We can take benefit of syslib implementations from the Emscripten |
| * by passing its generated js Module as the imports. |
| * |
| * @param bufferSource The source to be compiled. |
| * @param importObject The import objects. |
| * @param logger The system logger. |
| */ |
| export function instantiate( |
| bufferSource: ArrayBuffer, |
| importObject: Record<string, any> = {}, |
| logger: (msg: string) => void = console.log |
| ): Promise<Instance> { |
| const env = new Environment(importObject, logger); |
| |
| return WebAssembly.instantiate(bufferSource, env.imports).then( |
| (result: WebAssembly.WebAssemblyInstantiatedSource): Instance => { |
| return new Instance(result.module, {}, result.instance, env); |
| } |
| ); |
| } |