blob: d308e43c9f4d6bb128506947a0b722a16d2f3da4 [file] [log] [blame]
// Copyright 2021-2023 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import type {
AnyMessage,
Message,
MethodInfo,
PartialMessage,
ServiceType,
} from "@bufbuild/protobuf";
import { MethodIdempotency } from "@bufbuild/protobuf";
import { requestHeaderWithCompression } from "./request-header.js";
import { headerUnaryContentLength, headerUnaryEncoding } from "./headers.js";
import { validateResponseWithCompression } from "./validate-response.js";
import { trailerDemux } from "./trailer-mux.js";
import { errorFromJsonBytes } from "./error-json.js";
import { createEndStreamSerialization, endStreamFlag } from "./end-stream.js";
import { transformConnectPostToGetRequest } from "./get-request.js";
import type { CommonTransportOptions } from "../protocol/transport-options.js";
import { Code } from "../code.js";
import { DubboError } from "../dubbo-error.js";
import { appendHeaders } from "../http-headers.js";
import type {
UnaryResponse,
UnaryRequest,
StreamResponse,
StreamRequest,
} from "../interceptor.js";
import {
createAsyncIterable,
pipeTo,
sinkAllBytes,
pipe,
transformNormalizeMessage,
transformSerializeEnvelope,
transformCompressEnvelope,
transformJoinEnvelopes,
transformSplitEnvelope,
transformDecompressEnvelope,
transformParseEnvelope,
} from "../protocol/async-iterable.js";
import { createMethodUrl } from "../protocol/create-method-url.js";
import { runUnaryCall, runStreamingCall } from "../protocol/run-call.js";
import { createMethodSerializationLookup } from "../protocol/serialization.js";
import type { Transport } from "../transport.js";
import type { TripleClientServiceOptions } from './client-service-options.js';
import { ConsumerMeterCollector, createObservable } from '@apachedubbo/dubbo-observable';
import type { Observable, ObservableOptions } from '@apachedubbo/dubbo-observable';
/**
 * Module-level Observable service instance.
 *
 * It is left unassigned at load time; initObservable() creates and starts
 * it on first use and every later call reuses the same instance.
 */
let observable: Observable;
/**
 * Lazily create and start the module's Observable service.
 *
 * The first call builds the service with createObservable() and starts it.
 * Once an instance exists, further calls return immediately, so the options
 * passed to those later calls are not applied.
 *
 * @param observableOptions options forwarded to createObservable() when the
 *   service is created
 */
function initObservable(observableOptions?: ObservableOptions) {
  // Already created by an earlier call: nothing to do.
  if (observable) {
    return;
  }
  // Store the instance before starting it so it is recorded even if
  // start() throws.
  observable = createObservable(observableOptions);
  observable.start();
  // TODO: observable.shutdown()
}
/**
 * Interceptor type accepted by the transport options. It is derived from
 * CommonTransportOptions so that no additional import is needed.
 */
type TransportInterceptor = NonNullable<
  CommonTransportOptions["interceptors"]
>[number];

/**
 * Build the interceptor that reports consumer-side request metrics.
 *
 * For every request it records one "request" sample, followed by either a
 * "succeed" or a "failed" sample carrying the elapsed time in milliseconds.
 * The elapsed time is measured around `next(req)`; for streaming calls that
 * is presumably the moment the response object becomes available, not the
 * moment the message stream has been fully consumed.
 */
function createConsumerMetricsInterceptor(): TransportInterceptor {
  const consumerMeterCollector = new ConsumerMeterCollector({
    name: "DubboJS Consumer Metrics Collector",
    version: "v1.0.0",
  });
  return (next) => async (req) => {
    // Labels shared by all three samples, read from the outgoing headers.
    const labels = {
      service: req.service.typeName,
      method: req.method.name,
      serviceVersion: req.header.get("tri-service-version") ?? "",
      serviceGroup: req.header.get("tri-service-group") ?? "",
      protocolVersion: req.header.get("tri-protocol-version") ?? "",
      protocol: req.header.get("content-type") ?? "",
    };
    // Total number of requests sent.
    consumerMeterCollector.consumerRequest({ ...labels });
    const start = Date.now();
    try {
      const response = await next(req);
      // Total number of successful requests.
      consumerMeterCollector.consumerRequestSucceed({
        ...labels,
        service: response.service.typeName,
        method: response.method.name,
        rt: Date.now() - start,
      });
      return response;
    } catch (e) {
      // Total number of failed requests; the error is rethrown unchanged.
      consumerMeterCollector.consumerRequestFailed({
        ...labels,
        rt: Date.now() - start,
        error: String(e),
      });
      throw e;
    }
  };
}

/**
 * Create a Transport that sends unary and streaming calls through
 * `opt.httpClient`.
 *
 * When `opt.observableOptions.enable` is truthy, the Observable service is
 * initialized and a consumer metrics interceptor is appended after the
 * interceptors configured in `opt.interceptors`. The options object and its
 * interceptor array are never modified, so the same options (or the same
 * interceptor array) can safely be reused for several transports without the
 * metrics interceptor being registered more than once per transport.
 *
 * @param opt transport options (HTTP client, base URL, serialization,
 *   compression, interceptors, observability)
 * @returns a Transport exposing `unary` and `stream`
 */
export function createTransport(opt: CommonTransportOptions): Transport {
  let metricsInterceptor: TransportInterceptor | undefined;
  if (opt?.observableOptions?.enable) {
    // Enable and init observable service.
    initObservable(opt.observableOptions);
    metricsInterceptor = createConsumerMetricsInterceptor();
  }
  // Resolved per call so that `opt.interceptors` is read at call time, and
  // combined into a fresh array so the caller's array is left untouched.
  const resolveInterceptors = () =>
    metricsInterceptor === undefined
      ? opt.interceptors
      : [...(opt.interceptors ?? []), metricsInterceptor];
  return {
    /**
     * Run a unary call: serialize (and optionally compress) the request,
     * send it with POST or, when allowed, GET, then validate and parse the
     * response. A unary error response is thrown as an error built from the
     * JSON response body.
     */
    async unary<
      I extends Message<I> = AnyMessage,
      O extends Message<O> = AnyMessage
    >(
      service: ServiceType,
      method: MethodInfo<I, O>,
      signal: AbortSignal | undefined,
      timeoutMs: number | undefined,
      header: HeadersInit | undefined,
      message: PartialMessage<I>,
      serviceOptions?: TripleClientServiceOptions
    ): Promise<UnaryResponse<I, O>> {
      const serialization = createMethodSerializationLookup(
        method,
        opt.binaryOptions,
        opt.jsonOptions,
        opt
      );
      return await runUnaryCall<I, O>({
        interceptors: resolveInterceptors(),
        signal,
        timeoutMs,
        req: {
          stream: false,
          service,
          method,
          url: createMethodUrl(opt.baseUrl, service, method),
          init: {},
          header: requestHeaderWithCompression(
            method.kind,
            opt.useBinaryFormat,
            timeoutMs,
            header,
            opt.acceptCompression,
            opt.sendCompression,
            serviceOptions
          ),
          message:
            message instanceof method.I ? message : new method.I(message),
        },
        next: async (req: UnaryRequest<I, O>): Promise<UnaryResponse<I, O>> => {
          let requestBody = serialization
            .getI(opt.useBinaryFormat)
            .serialize(req.message);
          // Compress only when a send compression is configured and the
          // body exceeds compressMinBytes; otherwise make sure no encoding
          // header is sent for the uncompressed body.
          if (
            opt.sendCompression &&
            requestBody.byteLength > opt.compressMinBytes
          ) {
            requestBody = await opt.sendCompression.compress(requestBody);
            req.header.set(headerUnaryEncoding, opt.sendCompression.name);
          } else {
            req.header.delete(headerUnaryEncoding);
          }
          // GET is used only when explicitly enabled and the method is
          // declared free of side effects.
          const useGet =
            opt.useHttpGet === true &&
            method.idempotency === MethodIdempotency.NoSideEffects;
          let body: AsyncIterable<Uint8Array> | undefined;
          if (useGet) {
            // The request is rewritten for GET and no body is sent.
            req = transformConnectPostToGetRequest(
              req,
              requestBody,
              opt.useBinaryFormat
            );
          } else {
            body = createAsyncIterable([requestBody]);
          }
          const universalResponse = await opt.httpClient({
            url: req.url,
            method: req.init.method ?? "POST",
            header: req.header,
            signal: req.signal,
            body,
          });
          const { compression, isUnaryError, unaryError } =
            validateResponseWithCompression(
              method.kind,
              opt.acceptCompression,
              universalResponse.status,
              universalResponse.header
            );
          const [header, trailer] = trailerDemux(universalResponse.header);
          // Collect the complete response body before decoding it.
          let responseBody = await pipeTo(
            universalResponse.body,
            sinkAllBytes(
              opt.readMaxBytes,
              universalResponse.header.get(headerUnaryContentLength)
            ),
            { propagateDownStreamError: false }
          );
          if (compression) {
            responseBody = await compression.decompress(
              responseBody,
              opt.readMaxBytes
            );
          }
          if (isUnaryError) {
            throw errorFromJsonBytes(
              responseBody,
              appendHeaders(header, trailer),
              unaryError
            );
          }
          return <UnaryResponse<I, O>>{
            stream: false,
            service,
            method,
            header,
            message: serialization
              .getO(opt.useBinaryFormat)
              .parse(responseBody),
            trailer,
          };
        },
      });
    },
    /**
     * Run a streaming call: the input messages are normalized, serialized
     * into envelopes, optionally compressed and sent with POST. The response
     * body is split into envelopes, decompressed and parsed; the
     * end-of-stream envelope supplies the trailers or the error.
     */
    async stream<
      I extends Message<I> = AnyMessage,
      O extends Message<O> = AnyMessage
    >(
      service: ServiceType,
      method: MethodInfo<I, O>,
      signal: AbortSignal | undefined,
      timeoutMs: number | undefined,
      header: HeadersInit | undefined,
      input: AsyncIterable<I>
    ): Promise<StreamResponse<I, O>> {
      const serialization = createMethodSerializationLookup(
        method,
        opt.binaryOptions,
        opt.jsonOptions,
        opt
      );
      const endStreamSerialization = createEndStreamSerialization(
        opt.jsonOptions
      );
      return runStreamingCall<I, O>({
        interceptors: resolveInterceptors(),
        signal,
        timeoutMs,
        req: {
          stream: true,
          service,
          method,
          url: createMethodUrl(opt.baseUrl, service, method),
          init: {
            method: "POST",
            redirect: "error",
            mode: "cors",
          },
          header: requestHeaderWithCompression(
            method.kind,
            opt.useBinaryFormat,
            timeoutMs,
            header,
            opt.acceptCompression,
            opt.sendCompression
          ),
          message: pipe(input, transformNormalizeMessage(method.I), {
            propagateDownStreamError: true,
          }),
        },
        next: async (req: StreamRequest<I, O>) => {
          const uRes = await opt.httpClient({
            url: req.url,
            method: "POST",
            header: req.header,
            signal: req.signal,
            body: pipe(
              req.message,
              // Normalized a second time because the message iterable seen
              // here may have been replaced by an interceptor.
              transformNormalizeMessage(method.I),
              transformSerializeEnvelope(
                serialization.getI(opt.useBinaryFormat)
              ),
              transformCompressEnvelope(
                opt.sendCompression,
                opt.compressMinBytes
              ),
              transformJoinEnvelopes(),
              { propagateDownStreamError: true }
            ),
          });
          const { compression } = validateResponseWithCompression(
            method.kind,
            opt.acceptCompression,
            uRes.status,
            uRes.header
          );
          const res: StreamResponse<I, O> = {
            ...req,
            header: uRes.header,
            // Filled in from the end-of-stream envelope while the messages
            // are being iterated.
            trailer: new Headers(),
            message: pipe(
              uRes.body,
              transformSplitEnvelope(opt.readMaxBytes),
              transformDecompressEnvelope(
                compression ?? null,
                opt.readMaxBytes
              ),
              transformParseEnvelope(
                serialization.getO(opt.useBinaryFormat),
                endStreamFlag,
                endStreamSerialization
              ),
              async function* (iterable) {
                // Enforce that exactly one end-of-stream envelope arrives
                // and that it is the last envelope of the response.
                let endStreamReceived = false;
                for await (const chunk of iterable) {
                  if (chunk.end) {
                    if (endStreamReceived) {
                      throw new DubboError(
                        "protocol error: received extra EndStreamResponse",
                        Code.InvalidArgument
                      );
                    }
                    endStreamReceived = true;
                    if (chunk.value.error) {
                      throw chunk.value.error;
                    }
                    chunk.value.metadata.forEach((value, key) =>
                      res.trailer.set(key, value)
                    );
                    continue;
                  }
                  if (endStreamReceived) {
                    throw new DubboError(
                      "protocol error: received extra message after EndStreamResponse",
                      Code.InvalidArgument
                    );
                  }
                  yield chunk.value;
                }
                if (!endStreamReceived) {
                  throw new DubboError(
                    "protocol error: missing EndStreamResponse",
                    Code.InvalidArgument
                  );
                }
              },
              { propagateDownStreamError: true }
            ),
          };
          return res;
        },
      });
    },
  };
}