# blob: 7450cc1cfc66011edf5cd77dfc58f5a0a2f1f817 [file] [log] [blame]
#!/usr/bin/env python3
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# modified from https://github.com/zhengzangw/Fed-SINGA/blob/main/src/server/app.py
import socket
from collections import defaultdict
from typing import Dict, List
from singa import tensor
from .proto import interface_pb2 as proto
from .proto.utils import parseargs
from .proto import utils
class Server:
    """Server sends and receives protobuf messages.

    Create and start the server, then use pull and push to communicate with clients.

    Attributes:
        num_clients (int): Number of clients.
        host (str): Host address of the server.
        port (int): Port of the server.
        sock (socket.socket): Listening socket of the server.
        conns (List[socket.socket]): One connected socket per client, indexed by
            the client's global rank.
        addrs (List): Peer address of each client as returned by
            ``socket.accept``, indexed by the client's global rank.
        weights (Dict[str, tensor.Tensor]): Aggregated weights stored on the server.
    """

    def __init__(
        self,
        num_clients: int = 1,
        host: str = "127.0.0.1",
        port: int = 1234,
    ) -> None:
        """Class init method.

        Only allocates state and creates the (not yet bound) listening socket;
        binding, listening and accepting happen in :meth:`start`.

        Args:
            num_clients (int, optional): Number of clients in training. Defaults to 1.
            host (str, optional): Host ip address. Defaults to '127.0.0.1'.
            port (int, optional): Port. Defaults to 1234.
        """
        self.num_clients = num_clients
        self.host = host
        self.port = port
        self.sock = socket.socket()
        # Slots are filled by rank during __start_rank_pairing.
        self.conns = [None] * num_clients
        self.addrs = [None] * num_clients
        self.weights = {}

    def __start_connection(self) -> None:
        """Bind the listening socket to (host, port) and start listening."""
        self.sock.bind((self.host, self.port))
        self.sock.listen()
        print("Server started.")

    def __start_rank_pairing(self) -> None:
        """Accept ``num_clients`` connections and pair each with its global rank.

        Right after connecting, each client sends its rank (read with
        ``utils.receive_int``); the connection and peer address are stored at
        that index of ``conns`` / ``addrs``.

        Raises:
            ValueError: If a client reports a rank outside ``[0, num_clients)``.
            RuntimeError: If two clients report the same rank.
        """
        for _ in range(self.num_clients):
            conn, addr = self.sock.accept()
            rank = utils.receive_int(conn)
            # Validate explicitly: a negative rank would otherwise silently
            # index from the end of the list, and a duplicate rank would
            # overwrite (and leak) an earlier client's connection.
            if not 0 <= rank < self.num_clients:
                conn.close()
                raise ValueError(
                    f"[Server] Client {addr} sent invalid global_rank {rank}; "
                    f"expected 0 <= rank < {self.num_clients}"
                )
            if self.conns[rank] is not None:
                conn.close()
                raise RuntimeError(
                    f"[Server] Duplicate global_rank {rank} from {addr}"
                )
            self.conns[rank] = conn
            self.addrs[rank] = addr
            print(f"[Server] Connected by {addr} [global_rank {rank}]")
        # Exactly num_clients distinct, in-range ranks were accepted, so every
        # slot of self.conns is now filled.

    def start(self) -> None:
        """Start the server.

        This method will first bind and listen on the designated host and port.
        Then it will connect to num_clients clients and maintain the socket.
        In this process, each client shall provide their rank number.
        """
        self.__start_connection()
        self.__start_rank_pairing()

    def close(self) -> None:
        """Close every connected client socket, then the listening socket.

        Slots that were never connected (``None``) are skipped, so this is
        safe to call even if :meth:`start` failed or was never called.
        """
        for conn in self.conns:
            if conn is not None:
                conn.close()
        self.sock.close()

    def aggregate(
        self, weights: "Dict[str, List[tensor.Tensor]]"
    ) -> "Dict[str, tensor.Tensor]":
        """Aggregate collected weights to update server weight.

        For every key, the stored weight becomes ``sum(values) / num_clients``.
        The divisor is always ``num_clients``, so this is the plain mean only
        when every client contributed a value for that key. Keys absent from
        ``weights`` keep their previously stored value.

        Args:
            weights (Dict[str, List[tensor.Tensor]]): The collected weights.

        Returns:
            Dict[str, tensor.Tensor]: Updated weight stored in server
            (the ``self.weights`` dict itself, not a copy).
        """
        for name, values in weights.items():
            self.weights[name] = sum(values) / self.num_clients
        return self.weights

    def pull(self) -> None:
        """Server pull weights from clients.

        Namely clients push weights to the server. It is the gather process.
        Messages are received sequentially in rank order, deserialized,
        grouped by weight name, and then passed to :meth:`aggregate`.
        """
        weights = defaultdict(list)
        # receive weights sequentially, in rank order
        for conn in self.conns:
            data = utils.receive_message(conn, proto.WeightsExchange())
            for name, serialized in data.weights.items():
                weights[name].append(utils.deserialize_tensor(serialized))
        # aggregation
        self.aggregate(weights)

    def push(self) -> None:
        """Server push weights to clients.

        Namely clients pull weights from server. It is the scatter process.
        A single SCATTER message holding every stored weight is sent to each
        client in rank order.
        """
        message = proto.WeightsExchange()
        message.op_type = proto.SCATTER
        for name, value in self.weights.items():
            message.weights[name] = utils.serialize_tensor(value)
        for conn in self.conns:
            utils.send_message(conn, message)
if __name__ == "__main__":
    args = parseargs()
    server = Server(num_clients=args.num_clients, host=args.host, port=args.port)
    try:
        server.start()
        for i in range(args.max_epoch):
            print(f"On epoch {i}:")
            # Nothing has been aggregated before the first pull, so there is
            # nothing to push on epoch 0.
            if i > 0:
                # Push to Clients
                server.push()
            # Collects from Clients
            server.pull()
    finally:
        # Release the listening socket and all client connections even if
        # start/push/pull raises partway through training.
        server.close()