blob: 12f044ab50ce2aa598385dcbc0aada04138c9538 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import argparse
import pickle
import socket
import struct
from google.protobuf.message import Message
from singa import tensor
def receive_all(conn: socket.socket, size: int) -> bytes:
    """Receive exactly ``size`` bytes from a socket.

    A single ``recv`` call may return fewer bytes than requested, so this
    keeps reading until the requested amount has arrived.

    Args:
        conn (socket.socket): Socket connection.
        size (int): Number of bytes to receive. Zero (or a negative value)
            returns an empty result without reading from the socket.

    Raises:
        RuntimeError: If the connection is closed before ``size`` bytes were read.

    Returns:
        bytes: The received bytes.
    """
    # Collect the chunks and join once at the end: repeated ``bytes +=``
    # re-copies the whole buffer on every iteration, which becomes quadratic
    # for large payloads delivered in many small chunks.
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = conn.recv(remaining)
        if not chunk:
            # An empty read means the peer closed the connection.
            raise RuntimeError("connection closed before chunk was read")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)
def send_int(conn: socket.socket, i: int, pack_format: str = "Q") -> None:
    """Write one integer to a socket in network byte order.

    Args:
        conn (socket.socket): Socket connection.
        i (int): Integer to send.
        pack_format (str, optional): ``struct`` format character used to
            encode the integer. Defaults to "Q" (unsigned long long).
    """
    # "!" selects network (big-endian) byte order with standard sizes.
    packer = struct.Struct("!" + pack_format)
    conn.sendall(packer.pack(i))
def receive_int(conn: socket.socket, pack_format: str = "Q") -> int:
    """Receive an integer from socket.

    Args:
        conn (socket.socket): Socket connection.
        pack_format (str, optional): ``struct`` format character the integer
            was encoded with. Defaults to "Q", which means unsigned long long.

    Returns:
        int: Received integer.
    """
    # The byte count must be derived from the same "!"-prefixed format that
    # is used for unpacking. Without the prefix, struct reports the *native*
    # size, which differs from the standard size for some formats (e.g. "L"
    # is 8 bytes natively on 64-bit Linux but 4 bytes with "!"), causing the
    # wrong number of bytes to be read and the unpack to fail.
    codec = struct.Struct(f"!{pack_format}")
    payload = receive_all(conn, codec.size)
    (value,) = codec.unpack(payload)
    return value
def send_message(conn: socket.socket, data: Message, pack_format: str = "Q") -> None:
    """Send a length-prefixed protobuf message over a socket.

    The byte size of the message is sent first as an integer, followed by
    the serialized message itself.

    Args:
        conn (socket.socket): Socket connection.
        data (Message): Protobuf message to send.
        pack_format (str, optional): ``struct`` format character for the
            length prefix. Defaults to "Q" (unsigned long long).
    """
    # Header: how many bytes the receiver should expect.
    length = data.ByteSize()
    send_int(conn, length, pack_format)

    # Body: the serialized message.
    payload = data.SerializePartialToString()
    conn.sendall(payload)
def receive_message(conn: socket.socket, data: Message, pack_format: str = "Q") -> Message:
    """Read a length-prefixed protobuf message from a socket.

    The length prefix is read first, then that many bytes are received and
    parsed into ``data``.

    Args:
        conn (socket.socket): Socket connection.
        data (Message): Placeholder protobuf message that is filled in place.
        pack_format (str, optional): ``struct`` format character of the
            length prefix. Defaults to "Q" (unsigned long long).

    Returns:
        Message: The same ``data`` object, populated from the received bytes.
    """
    length = receive_int(conn, pack_format)
    payload = receive_all(conn, length)
    data.ParseFromString(payload)
    return data
def serialize_tensor(t: tensor.Tensor) -> bytes:
    """Turn a singa tensor into bytes.

    The tensor is converted with ``tensor.to_numpy`` and the result is
    pickled using protocol 0.

    Args:
        t (tensor.Tensor): The singa tensor.

    Returns:
        bytes: The pickled representation of the tensor's data.
    """
    array = tensor.to_numpy(t)
    return pickle.dumps(array, protocol=0)
def deserialize_tensor(t: bytes) -> tensor.Tensor:
    """Rebuild a singa tensor from pickled bytes.

    The bytes are unpickled and the result is passed to
    ``tensor.from_numpy``.

    Args:
        t (bytes): The serialized tensor.

    Returns:
        tensor.Tensor: The singa tensor.
    """
    # NOTE(review): pickle.loads can execute arbitrary code while loading;
    # only pass bytes that come from a trusted peer.
    array = pickle.loads(t)
    return tensor.from_numpy(array)
def parseargs(arg=None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        arg (list[str], optional): Argument strings to parse. When None,
            argparse falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: parsed arguments
    """
    parser = argparse.ArgumentParser(description="Training using the autograd and graph.")
    parser.add_argument(
        "--model", choices=["cnn", "resnet", "xceptionnet", "mlp", "alexnet"], default="mlp"
    )
    parser.add_argument("--data", choices=["mnist", "cifar10", "cifar100", "bank"], default="mnist")
    parser.add_argument("-p", choices=["float32", "float16"], default="float32", dest="precision")
    parser.add_argument(
        "-m", "--max-epoch", default=10, type=int, help="maximum epochs", dest="max_epoch"
    )
    parser.add_argument(
        "-b", "--batch-size", default=64, type=int, help="batch size", dest="batch_size"
    )
    parser.add_argument(
        "-l", "--learning-rate", default=0.005, type=float, help="initial learning rate", dest="lr"
    )
    # Determine which gpu to use
    parser.add_argument(
        "-i", "--device-id", default=0, type=int, help="which GPU to use", dest="device_id"
    )
    # The flag stores False, so the default must be the boolean True (not the
    # string "True") for ``graph`` to always be a real bool.
    parser.add_argument(
        "-g",
        "--disable-graph",
        default=True,
        action="store_false",
        help="disable graph",
        dest="graph",
    )
    parser.add_argument(
        "-v", "--log-verbosity", default=0, type=int, help="logging verbosity", dest="verbosity"
    )
    parser.add_argument(
        "-d",
        "--data-distribution",
        choices=["iid", "non-iid"],
        default="iid",
        help="data distribution",
        dest="data_dist",
    )
    parser.add_argument("--num_clients", default=10, type=int)
    parser.add_argument("--host", default="127.0.0.1")
    # type=int keeps a user-supplied port the same type as the integer default.
    parser.add_argument("--port", default=1234, type=int)
    return parser.parse_args(arg)