blob: d5b737f2aef5b4f161990f08d83e97638cf9b570 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import * as http from 'http';
import * as net from 'net';
import WebSocket from 'ws';
import { v4 as uuidv4 } from 'uuid';
import jwt from 'jsonwebtoken';
import cookie from 'cookie';
import Redis from 'ioredis';
import StatsD from 'hot-shots';
import { createLogger } from './logger';
import { buildConfig } from './config';
/**
 * One entry read from a Redis stream: the entry ID followed by a two-element
 * record whose first element is the literal label `data` and whose second
 * element is the payload string (parsed with JSON.parse in
 * processStreamResults).
 */
export type StreamResult = [
recordId: string,
record: [label: 'data', data: string],
];
// sync with superset-frontend/src/components/ErrorMessage/types
/** Severity levels an error entry can carry. */
export type ErrorLevel = 'info' | 'warning' | 'error';
/**
 * Error descriptor carried in an event's optional `errors` list.
 * `ExtraType` describes the shape of `extra` and defaults to an arbitrary
 * record or null.
 */
export type SupersetError<ExtraType = Record<string, any> | null> = {
error_type: string;
extra: ExtraType;
level: ErrorLevel;
message: string;
};
/**
 * Callback that receives a batch of stream entries; used by both
 * fetchRangeFromStream and subscribeToGlobalStream.
 */
type ListenerFunction = (results: StreamResult[]) => void;
/**
 * Event payload that sendToChannel serializes with JSON.stringify and sends
 * to every socket of a channel.
 */
interface EventValue {
// Stream entry ID supplied by processStreamResults; a payload-level `id`,
// if present, takes precedence because the parsed payload is spread after it.
id: string;
// Used by processStreamResults to pick the channel the event is sent to.
channel_id: string;
job_id: string;
user_id?: string;
status: string;
errors?: SupersetError[];
result_url?: string;
}
/** Shape this module expects from the verified JWT cookie. */
interface JwtPayload {
// Channel key under which the socket is registered (see wsConnection).
channel: string;
}
/** Arguments accepted by fetchRangeFromStream. */
interface FetchRangeFromStreamParams {
// Appended to opts.redisStreamPrefix to form the stream name.
sessionId: string;
// Lower and upper bounds handed to redis.xrange.
startId: string;
endId: string;
// Invoked with the reply only when it contains at least one entry.
listener: ListenerFunction;
}
/** A tracked WebSocket connection, stored in the `sockets` registry. */
export interface SocketInstance {
ws: WebSocket;
channel: string;
// Date.now() at connect time, refreshed whenever a pong is received;
// checkSockets compares it against opts.socketResponseTimeoutMs.
pongTs: number;
}
/** Connection settings consumed by redisUrlFromConfig. */
interface RedisConfig {
port: number;
host: string;
// When falsy (absent, null or empty) no credentials are put in the URL.
password?: string | null;
db: number;
// Selects the `rediss://` scheme instead of `redis://`.
ssl: boolean;
}
/** Per-channel entry of the `channels` registry. */
interface ChannelValue {
// Socket IDs (keys of the `sockets` registry) generated by trackClient.
sockets: Array<string>;
}
// Value of NODE_ENV; the logger below is silenced when it equals 'test'.
const environment = process.env.NODE_ENV;
// True only when the first CLI argument is `start`. Gates the JWT secret
// check below and the server startup section at the end of this file, so
// importing the module (e.g. from tests) does not start a server.
const startServer = process.argv[2] === 'start';
// Runtime configuration, exported so other modules can read it.
export const opts = buildConfig();
// init logger
const logger = createLogger({
silent: environment === 'test',
logLevel: opts.logLevel,
logToFile: opts.logToFile,
logFilename: opts.logFilename,
});
// StatsD client built from opts.statsd; client errors are routed to the
// logger through the errorHandler option.
export const statsd = new StatsD({
...opts.statsd,
errorHandler: (e: Error) => {
logger.error(e);
},
});
// enforce JWT secret length
// NOTE(review): `.length` counts string code units, which equals bytes only
// for single-byte characters - confirm whether a byte count is intended.
if (startServer && opts.jwtSecret.length < 32)
throw new Error('Please provide a JWT secret at least 32 bytes long');
/**
 * Builds a Redis connection URL from the given connection settings.
 *
 * The result has the form `redis[s]://[:password@]host:port/db`, using the
 * `rediss` scheme when `ssl` is set and omitting the credentials section when
 * no password is configured.
 *
 * @param redisConfig - host, port, database index, optional password and
 *   SSL flag
 * @returns the connection URL
 */
export const redisUrlFromConfig = (redisConfig: RedisConfig): string => {
  const scheme = redisConfig.ssl ? 'rediss' : 'redis';
  // Percent-encode the password so characters that are URL delimiters
  // (`@`, `/`, `:`, `#`, `%`, ...) cannot be mistaken for part of the host,
  // path or fragment. URL parsers decode the userinfo section again, so the
  // original password is what ends up being used for authentication.
  const credentials = redisConfig.password
    ? `:${encodeURIComponent(redisConfig.password)}@`
    : '';
  return `${scheme}://${credentials}${redisConfig.host}:${redisConfig.port}/${redisConfig.db}`;
};
// initialize servers
// Redis client shared by fetchRangeFromStream and subscribeToGlobalStream.
const redis = new Redis(redisUrlFromConfig(opts.redis));
// Plain HTTP server; its `request` and `upgrade` listeners are attached in
// the server startup section at the end of this file.
const httpServer = http.createServer();
// WebSocket server running in `noServer` mode: it does not listen on its own,
// upgrade requests are handed to it explicitly via wss.handleUpgrade in
// httpUpgrade. `clientTracking` is off; connections are tracked in the
// `channels` / `sockets` registries below instead.
export const wss = new WebSocket.Server({
noServer: true,
clientTracking: false,
});
// Ready states for which checkSockets and cleanChannel keep a socket.
const SOCKET_ACTIVE_STATES = [WebSocket.OPEN, WebSocket.CONNECTING];
// Name of the stream passed to subscribeToGlobalStream at startup.
const GLOBAL_EVENT_STREAM_NAME = `${opts.redisStreamPrefix}full`;
// Initial read position for the global stream. `$` is the Redis stream
// special ID used with XREAD; wsConnection substitutes `+` for it when it is
// needed as the upper bound of a range read.
const DEFAULT_STREAM_LAST_ID = '$';
// initialize internal registries
// channel name -> IDs of the sockets registered on that channel
export let channels: Record<string, ChannelValue> = {};
// socket ID (generated by trackClient) -> tracked connection
export let sockets: Record<string, SocketInstance> = {};
// ID of the last entry handled from the global stream; passed as the read
// position of every XREAD issued by subscribeToGlobalStream.
let lastFirehoseId: string = DEFAULT_STREAM_LAST_ID;
/**
 * Overwrites the position from which the global stream is read next.
 */
export const setLastFirehoseId = (id: string): void => {
lastFirehoseId = id;
};
/**
 * Registers a newly connected socket.
 *
 * A fresh UUID is generated for the socket, the instance is stored in the
 * `sockets` registry under that ID, and the ID is appended to the socket list
 * of `channel` (the channel entry is created on first use). The
 * `ws_connected_client` counter is incremented on every call.
 *
 * @returns the generated socket ID
 */
export const trackClient = (
  channel: string,
  socketInstance: SocketInstance,
): string => {
  statsd.increment('ws_connected_client');

  const newSocketId = uuidv4();
  sockets[newSocketId] = socketInstance;

  if (!(channel in channels)) {
    channels[channel] = { sockets: [] };
  }
  channels[channel].sockets.push(newSocketId);

  return newSocketId;
};
/**
 * Delivers one async event payload to every socket registered on a channel.
 *
 * The payload is serialized once and sent to each socket ID listed for the
 * channel. Unknown channels are skipped with a debug log. Whenever a listed
 * socket is missing from the `sockets` registry, or sending throws, the
 * channel is cleaned via cleanChannel; a failed send additionally bumps the
 * `ws_client_send_error` counter.
 */
export const sendToChannel = (channel: string, value: EventValue): void => {
  const payload = JSON.stringify(value);

  const entry = channels[channel];
  if (!entry) {
    logger.debug(`channel ${channel} is unknown, skipping`);
    return;
  }

  // cleanChannel replaces the channel's socket list with a new array, so the
  // list captured here stays stable for the duration of the loop.
  for (const id of entry.sockets) {
    const target: SocketInstance | undefined = sockets[id];
    if (!target) {
      cleanChannel(channel);
      continue;
    }
    try {
      target.ws.send(payload);
    } catch (err) {
      statsd.increment('ws_client_send_error');
      logger.debug(`Error sending to socket: ${err}`);
      // check that the connection is still active
      cleanChannel(channel);
    }
  }
};
/**
 * Reads a range of entries from the Redis stream belonging to one channel
 * and hands them to `listener`. Used when a client reconnects.
 *
 * The stream name is `opts.redisStreamPrefix` followed by `sessionId`;
 * `startId` and `endId` are passed to XRANGE as given. The listener is not
 * called for an empty reply. Errors (from Redis or from the listener) are
 * logged and never propagate to the caller.
 */
export const fetchRangeFromStream = async ({
  sessionId,
  startId,
  endId,
  listener,
}: FetchRangeFromStreamParams) => {
  const channelStream = `${opts.redisStreamPrefix}${sessionId}`;
  try {
    const entries = await redis.xrange(channelStream, startId, endId);
    if (entries?.length) {
      listener(entries as StreamResult[]);
    }
  } catch (e) {
    logger.error(e);
  }
};
/**
 * Continuously reads the global Redis event stream and forwards each batch
 * of entries to `listener`. Never returns.
 *
 * Every iteration issues a blocking XREAD starting after `lastFirehoseId`.
 * When entries arrive, the listener is called first and the read position is
 * then advanced to the ID of the last entry of the batch. Empty replies are
 * ignored; errors are logged and the loop carries on.
 */
export const subscribeToGlobalStream = async (
  stream: string,
  listener: ListenerFunction,
) => {
  /*eslint no-constant-condition: ["error", { "checkLoops": false }]*/
  while (true) {
    try {
      const reply = await redis.xread(
        'BLOCK',
        opts.redisStreamReadBlockMs,
        'COUNT',
        opts.redisStreamReadCount,
        'STREAMS',
        stream,
        lastFirehoseId,
      );
      // a single stream is requested, so only the first reply slot is used
      const entries = reply ? reply[0][1] : null;
      if (entries && entries.length) {
        const count = entries.length;
        listener(entries as StreamResult[]);
        setLastFirehoseId(entries[count - 1][0]);
      }
    } catch (e) {
      logger.error(e);
    }
  }
};
/**
 * Listener for batches of entries read from a Redis stream.
 *
 * Each entry's payload is JSON-parsed and sent to the channel named by its
 * `channel_id`, with the stream entry ID added as `id`. A failure on one
 * entry is logged and does not prevent the remaining entries from being
 * processed.
 */
export const processStreamResults = (results: StreamResult[]): void => {
  logger.debug(`events received: ${results}`);
  for (const entry of results) {
    try {
      const [entryId, record] = entry;
      const event = JSON.parse(record[1]);
      sendToChannel(event.channel_id, { id: entryId, ...event });
    } catch (err) {
      logger.error(err);
    }
  }
};
/**
 * Reads the JWT from the request's cookies and verifies it.
 *
 * The cookie named `opts.jwtCookieName` is verified against `opts.jwtSecret`.
 *
 * @returns the verified token payload
 * @throws Error('JWT not present') when the cookie is missing or empty;
 *   otherwise whatever `jwt.verify` throws for an invalid token
 */
const getJwtPayload = (request: http.IncomingMessage): JwtPayload => {
  const cookieHeader = request.headers.cookie || '';
  const jwtCookie = cookie.parse(cookieHeader)[opts.jwtCookieName];
  if (jwtCookie) {
    return jwt.verify(jwtCookie, opts.jwtSecret) as JwtPayload;
  }
  throw new Error('JWT not present');
};
/**
 * Extracts the `last_id` query param value from an HTTP request.
 *
 * @param request - incoming request whose `url` holds the request target
 * @returns the value of `last_id`, or `null` when the parameter is absent or
 *   the request target cannot be parsed as a URL
 */
const getLastId = (request: http.IncomingMessage): string | null => {
  try {
    // Only the query string is of interest; the fixed base merely allows a
    // path-only request target to be parsed.
    const url = new URL(String(request.url), 'http://0.0.0.0');
    return url.searchParams.get('last_id');
  } catch {
    // `new URL` throws on targets such as `//?last_id=1` (empty host). The
    // request target is client-controlled, so treat an unparseable one as
    // "no last_id" instead of letting the exception escape the handler.
    return null;
  }
};
/**
 * Increments a Redis Stream ID by one sequence step.
 *
 * @param id - a stream ID in the `<milliseconds>-<sequence>` format,
 *   e.g. '1607477697866-0'
 * @returns the ID with its sequence part increased by one; an `id` that does
 *   not match the expected format is returned unchanged
 */
export const incrementId = (id: string): string => {
  // redis stream IDs are in this format: '1607477697866-0'
  const match = /^(\d+)-(\d+)$/.exec(id);
  // Leave anything else untouched rather than producing an ID like 'x-NaN'.
  if (!match) return id;
  const [, timestamp, sequence] = match;
  // The sequence part is a 64-bit integer in Redis; BigInt keeps the
  // increment exact beyond Number.MAX_SAFE_INTEGER.
  return `${timestamp}-${BigInt(sequence) + BigInt(1)}`;
};
/**
 * WebSocket `connection` event handler, called via wss.
 *
 * Registers the socket under the channel named in the request's JWT cookie,
 * replays missed events when the request carries a `last_id` query param,
 * and attaches the `error` and `pong` listeners used for connection
 * management.
 *
 * @param ws - the newly established WebSocket
 * @param request - the HTTP request that was upgraded into `ws`
 * @throws whatever `getJwtPayload` throws when the JWT cookie is missing or
 *   invalid
 */
export const wsConnection = (ws: WebSocket, request: http.IncomingMessage) => {
  const jwtPayload: JwtPayload = getJwtPayload(request);
  const channel: string = jwtPayload.channel;
  const socketInstance: SocketInstance = { ws, channel, pongTs: Date.now() };

  // add this ws instance to the internal registry
  const socketId = trackClient(channel, socketInstance);
  logger.debug(`socket ${socketId} connected on channel ${channel}`);

  // An `error` event emitted on a socket that has no `error` listener is
  // thrown as an uncaught exception, which would bring down the whole server
  // because of a single misbehaving client (e.g. one sending an invalid
  // frame). Log it instead; the socket is dropped from the registries by
  // checkSockets / cleanChannel once it is no longer active.
  ws.on('error', (err: Error) => {
    logger.warn(`socket ${socketId} error: ${err}`);
  });

  // reconnection logic
  const lastId = getLastId(request);
  if (lastId) {
    // fetch the range of events from lastId up to the most recent event
    // received via the global event stream
    const endId =
      lastFirehoseId === DEFAULT_STREAM_LAST_ID ? '+' : lastFirehoseId;
    fetchRangeFromStream({
      sessionId: channel,
      startId: incrementId(lastId), // inclusive
      endId, // inclusive
      listener: processStreamResults,
    });
  }

  // init event handler for `pong` events (connection management)
  // The pong payload echoes the socket ID that checkSockets sent in the ping.
  ws.on('pong', function pong(data: Buffer) {
    const pongSocketId = data.toString();
    const pongSocket = sockets[pongSocketId];
    if (!pongSocket) {
      logger.warn(`pong received for nonexistent socket ${pongSocketId}`);
    } else {
      pongSocket.pongTs = Date.now();
    }
  });
};
/**
 * HTTP `request` event handler, called via httpServer.
 *
 * Answers `GET`/`HEAD /health` with `200 OK`, a request target that cannot
 * be parsed with `400 Bad Request`, and everything else with
 * `404 Not Found`.
 *
 * @param request - the incoming HTTP request
 * @param response - the response to write to
 */
export const httpRequest = (
  request: http.IncomingMessage,
  response: http.ServerResponse,
) => {
  const rawUrl = request.url as string;
  const method = request.method as string;

  let pathname: string;
  try {
    // Only the path is inspected, so resolve against a fixed base instead of
    // the client-supplied Host header: a malformed Host must not be able to
    // make URL parsing throw.
    pathname = new URL(rawUrl, 'http://0.0.0.0').pathname;
  } catch {
    // Request targets such as `//` are rejected by `new URL`; respond
    // instead of letting the exception escape the request listener.
    logger.info(`Received malformed request: ${method} ${rawUrl}`);
    response.writeHead(400);
    response.end('Bad Request');
    return;
  }

  if (pathname === '/health' && ['GET', 'HEAD'].includes(method)) {
    response.writeHead(200);
    response.end('OK');
  } else {
    logger.info(`Received unexpected request: ${method} ${rawUrl}`);
    response.writeHead(404);
    response.end('Not Found');
  }
};
/**
 * HTTP `upgrade` event handler, called via httpServer.
 *
 * The request must carry a valid JWT cookie whose payload names a channel.
 * If it does not, the error is logged and the raw socket is destroyed
 * without establishing a WebSocket connection. Otherwise the request is
 * upgraded and the resulting socket is announced through the `connection`
 * event of wss.
 */
export const httpUpgrade = (
  request: http.IncomingMessage,
  socket: net.Socket,
  head: Buffer,
) => {
  try {
    const { channel } = getJwtPayload(request);
    if (!channel) throw new Error('Channel ID not present');
  } catch (err) {
    // JWT invalid, do not establish a WebSocket connection
    logger.error(err);
    socket.destroy();
    return;
  }

  // upgrade the HTTP request into a WebSocket connection
  const announceConnection = (
    ws: WebSocket,
    upgradedRequest: http.IncomingMessage,
  ) => {
    wss.emit('connection', ws, upgradedRequest);
  };
  wss.handleUpgrade(request, socket, head, announceConnection);
};
// Connection cleanup and garbage collection
/**
 * Walks the `sockets` registry once, pinging live connections and forgetting
 * dead ones.
 *
 * A socket whose last pong is at least `opts.socketResponseTimeoutMs` old is
 * terminated and removed from the registry. A socket that is within the
 * timeout but not in an active ready state is removed without being
 * terminated. Every remaining socket is sent a ping carrying its own socket
 * ID as payload.
 */
export const checkSockets = () => {
  logger.debug(`channel count: ${Object.keys(channels).length}`);
  logger.debug(`socket count: ${Object.keys(sockets).length}`);

  for (const socketId in sockets) {
    const { ws, channel, pongTs } = sockets[socketId];
    const timedOut = Date.now() - pongTs >= opts.socketResponseTimeoutMs;

    if (timedOut) {
      logger.debug(
        `terminating unresponsive socket: ${socketId}, channel: ${channel}`,
      );
      ws.terminate();
    }

    if (!timedOut && SOCKET_ACTIVE_STATES.includes(ws.readyState)) {
      ws.ping(socketId);
      continue;
    }

    delete sockets[socketId];
    logger.debug(`forgetting socket ${socketId}`);
  }
};
/**
 * Prunes a channel's socket list down to the sockets that are still tracked
 * in the `sockets` registry and in an active ready state. When none remain
 * (or the channel is unknown) the channel is removed from the `channels`
 * registry.
 */
export const cleanChannel = (channel: string) => {
  const isLive = (socketId: string): boolean => {
    const tracked = sockets[socketId];
    return (
      Boolean(tracked) && SOCKET_ACTIVE_STATES.includes(tracked.ws.readyState)
    );
  };

  const entry = channels[channel];
  const remaining: string[] = entry ? entry.sockets.filter(isLive) : [];

  if (remaining.length > 0) {
    channels[channel].sockets = remaining;
  } else {
    delete channels[channel];
  }
};
// server startup
// Runs only when the process was launched with the `start` argument.
if (startServer) {
  // garbage collection routine: clean all channels
  const gcChannels = (): void => {
    Object.keys(channels).forEach(channel => {
      cleanChannel(channel);
    });
  };

  // init server event listeners
  wss.on('connection', wsConnection);
  httpServer.on('request', httpRequest);
  httpServer.on('upgrade', httpUpgrade);
  httpServer.listen(opts.port);
  logger.info(`Server started on port ${opts.port}`);

  // start reading from event stream
  subscribeToGlobalStream(GLOBAL_EVENT_STREAM_NAME, processStreamResults);

  // init garbage collection routines
  setInterval(checkSockets, opts.pingSocketsIntervalMs);
  setInterval(gcChannels, opts.gcChannelsIntervalMs);
}
// test utilities
/**
 * Returns the module's mutable state to its initial values: the global
 * stream read position goes back to the default ID and both registries are
 * replaced with empty objects.
 */
export const resetState = () => {
  lastFirehoseId = DEFAULT_STREAM_LAST_ID;
  sockets = {};
  channels = {};
};