Half precision training mang lại các lợi ích:
Theo tiêu chuẩn IEEE 754 chỉ rõ binary16 có các định dạng sau: Sign bit: 1 bit Exponent width: 5 bits Significand precision: 11 bits (lưu trữ chính xác 10)
Tải dữ liệu ở dạng fp32 và dễ dàng đổi sang fp16 bằng casting.
>>> from singa import tensor, device >>> dev = device.create_cuda_gpu() >>> x = tensor.random((2,3),dev) >>> x [[0.7703407 0.42764223 0.5872884 ] [0.78362167 0.70469785 0.64975065]], float32 >>> y = x.as_type(tensor.float16) >>> y [[0.7705 0.4277 0.5874] [0.7837 0.7046 0.65 ]], float16
Các chương trình cơ bản được hỗ trợ với fp16.
>>> y+y [[1.541 0.8555 1.175 ] [1.567 1.409 1.3 ]], float16
Training theo half precision được thực hiện đơn giản qua ba bước:
# cast dữ liệu đầu vào sang fp16 x = load_data() x = x.astype(np.float16) tx = tensor.from_numpy(x) # tải model model = build_model() # chuyển optimizer dtype sang fp16 sgd = opt.SGD(lr=0.1, dtype=tensor.float16) # train như bình thường out, loss = model(tx, ty)
Tập tin ví dụ train_cnn.py
, chạy lệnh dưới đây để train theo half.
python examples/cnn/train_cnn.py cnn mnist -pfloat16
Thực hiện theo dạng half được tích hợp ở C++ backend như hỗ trợ dạng half nói chung.
Để chạy trên GPU, __half
có trên API của Cuda math. Để hỗ trợ chạy __half
math, cần compile với Nvidia compute arch > 6.0 (Pascal).
Tensor Core phát hành bởi Nvidia gia tốc half precision và các throughput cho các hàm như GEMM(CuBlas) và convolution(CuDNN). Khi kích hoạt hàm Tensor core, có một vài hạn chế về quy mô GEMM, kích thước kênh convolution, phiên bản Cuda, và phiên bản GPU (Turing hoặc mới hơn) v.v.
Hàm Half cơ bản được thực hiện trong tensor_math_cuda.h
, bằng cách chuyên môn hoá mô hình thực hiện với half type và áp dụng low level computation.
Ví dụ, hàm GEMM được thực hiện như sau:
template <> void GEMM<half_float::half, lang::Cuda>(const half_float::half alpha, const Tensor& A, const Tensor& B, const half_float::half beta, Tensor* C, Context* ctx) { // ... CUBLAS_CHECK(cublasGemmEx(handle, transb, transa, ncolB, nrowA, ncolA, alphaPtr, BPtr, Btype, ldb, APtr, Atype, lda, betaPtr, CPtr, Ctype, ldc, computeType, algo)); // ... }