| #!/bin/python |
| |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| # Usage: python bpacking_avx512_codegen.py > bpacking_avx512_generated.h |
| |
| |
def _print_set_epi32(reg, lanes, per_line):
    """Print ``reg = _mm512_set_epi32(...)`` for the given lane expressions.

    ``lanes`` is ordered from lane 0 upwards; the arguments are emitted from
    the last lane down to lane 0, wrapped ``per_line`` arguments per printed
    line.
    """
    ordered = [str(lane) for lane in reversed(lanes)]
    rows = [", ".join(ordered[pos:pos + per_line])
            for pos in range(0, len(ordered), per_line)]
    last_row = len(rows) - 1
    for row_no, row in enumerate(rows):
        head = f" {reg} = _mm512_set_epi32(" if row_no == 0 else " "
        tail = ");" if row_no == last_row else ","
        print(f"{head}{row}{tail}")


def _print_unpack_group(ordinal, shifts, inls):
    """Print the load/shift/mask/store statements for one group of 16 outputs."""
    print(f" // shift the {ordinal} 16 outs")
    _print_set_epi32("reg_shifts", shifts, 4)
    _print_set_epi32("reg_inls", inls, 2)
    print(
        " results = _mm512_and_epi32(_mm512_srlv_epi32(reg_inls, reg_shifts), reg_masks);")
    print(" _mm512_storeu_si512(out, results);")
    print(" out += 16;")
    print("")


def print_unpack_bit_func(bit):
    """Print the C++ function ``unpack{bit}_32_avx512`` to standard output.

    The generated function produces 32 outputs in two groups of 16. For each
    output it loads the 32-bit input word(s) holding the value, shifts right
    by the value's bit offset and masks with ``(1 << bit) - 1``. It then
    advances ``in`` by ``bit`` words and returns it.

    Args:
        bit: Width in bits of each packed value, from 0 to 32 inclusive.

    Raises:
        ValueError: If ``bit`` is outside 0..32. Wider values would emit
            shifts of 32 or more and a mask that does not fit ``uint32_t``.
    """
    if not 0 <= bit <= 32:
        raise ValueError(f"bit width must be in the range 0..32, got {bit}")

    mask = (1 << bit) - 1

    # Work out, for each of the 32 outputs, the expression that loads its
    # source word and the right shift that aligns the value to bit 0.
    shifts = []
    inls = []
    shift = 0
    in_index = 0
    for _ in range(32):
        end = shift + bit
        if end > 32:
            # The value straddles two input words: stitch both halves together
            # in the load expression, so no further lane shift is needed.
            inls.append(
                f"SafeLoad(in + {in_index}) >> {shift} | SafeLoad(in + {in_index + 1}) << {32 - shift}")
            shifts.append(0)
            in_index += 1
            shift = end - 32
        else:
            shifts.append(shift)
            inls.append(f"SafeLoad(in + {in_index})")
            if end == 32:
                # The value ends exactly on the word boundary: the next value
                # starts at bit 0 of the following word.
                in_index += 1
                shift = 0
            else:
                shift = end

    print(
        f"inline static const uint32_t* unpack{bit}_32_avx512(const uint32_t* in, uint32_t* out) {{")
    print(" using ::arrow::util::SafeLoad;")
    print(f" uint32_t mask = 0x{mask:x};")
    print(" __m512i reg_shifts, reg_inls, reg_masks;")
    print(" __m512i results;")
    print("")
    print(" reg_masks = _mm512_set1_epi32(mask);")
    print("")
    _print_unpack_group("first", shifts[:16], inls[:16])
    _print_unpack_group("second", shifts[16:], inls[16:])
    print(f" in += {bit};")
    print("")
    print(" return in;")
    print("}")
| |
| |
def print_unpack_bit0_func():
    """Print the C++ function ``unpack0_32_avx512`` to standard output.

    The generated function zero-fills 32 output values with ``memset`` and
    returns ``in`` unchanged.
    """
    signature = (
        "inline static const uint32_t* unpack0_32_avx512"
        "(const uint32_t* in, uint32_t* out) {"
    )
    statements = (
        " memset(out, 0x0, 32 * sizeof(*out));",
        " out += 32;",
        "",
        " return in;",
    )
    for line in (signature, *statements, "}"):
        print(line)
| |
| |
def print_unpack_bit32_func():
    """Print the C++ function ``unpack32_32_avx512`` to standard output.

    The generated function copies 32 values from ``in`` to ``out`` with
    ``memcpy``, advances both pointers by 32 and returns ``in``.
    """
    signature = (
        "inline static const uint32_t* unpack32_32_avx512"
        "(const uint32_t* in, uint32_t* out) {"
    )
    statements = (
        " memcpy(out, in, 32 * sizeof(*out));",
        " in += 32;",
        " out += 32;",
        "",
        " return in;",
    )
    for line in (signature, *statements, "}"):
        print(line)
| |
| |
def print_copyright():
    """Print the Apache License 2.0 notice as ``//`` comment lines."""
    notice = (
        "// Licensed to the Apache Software Foundation (ASF) under one",
        "// or more contributor license agreements. See the NOTICE file",
        "// distributed with this work for additional information",
        "// regarding copyright ownership. The ASF licenses this file",
        "// to you under the Apache License, Version 2.0 (the",
        '// "License"); you may not use this file except in compliance',
        "// with the License. You may obtain a copy of the License at",
        "//",
        "// http://www.apache.org/licenses/LICENSE-2.0",
        "//",
        "// Unless required by applicable law or agreed to in writing,",
        "// software distributed under the License is distributed on an",
        '// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY',
        "// KIND, either express or implied. See the License for the",
        "// specific language governing permissions and limitations",
        "// under the License.",
    )
    # One print call: the lines are joined by newlines and print() supplies
    # the final one.
    print("\n".join(notice))
| |
| |
def print_note():
    """Print the comment lines marking the output as an auto-generated file."""
    notice = ("//", "// Automatically generated file; DO NOT EDIT.")
    print("\n".join(notice))
| |
| |
def main():
    """Print the complete generated C++ header to standard output.

    The output consists of the licence and "generated file" comments, the
    include preamble, the opening of ``arrow::internal``, one unpack function
    for every bit width from 0 to 32, and the namespace closers.
    """
    print_copyright()
    print_note()

    preamble = (
        "",
        "#pragma once",
        "",
        "#include <stdint.h>",
        "#include <string.h>",
        "",
        "#ifdef _MSC_VER",
        "#include <intrin.h>",
        "#else",
        "#include <immintrin.h>",
        "#endif",
        "",
        '#include "arrow/util/ubsan.h"',
        "",
        "namespace arrow {",
        "namespace internal {",
        "",
    )
    for line in preamble:
        print(line)

    # Widths 0 and 32 have dedicated generators; 1..31 share the generic one.
    # Every function except the last is followed by a blank separator line.
    print_unpack_bit0_func()
    print("")
    for width in range(1, 32):
        print_unpack_bit_func(width)
        print("")
    print_unpack_bit32_func()

    for line in ("", "} // namespace internal", "} // namespace arrow"):
        print(line)
| |
| |
# Generate the header only when run as a script, not when imported.
if __name__ == "__main__":
    main()