an(float a, float b, float c);

    __device__ float __nv_ptx_builtin_ocg_fmin3_ftz(float a, float b, float c);
    __device__ float __nv_ptx_builtin_ocg_fmax3_ftz(float a, float b, float c);
}

#endif

// Returns the maximum of 128 floats stored as raw 32-bit bit patterns in `reg`.
//
// The reduction is a fixed tree of three-input max operations
// (__nv_ptx_builtin_ocg_fmax3):
//   reg[0..125]  -> 42 partials (tmp_max_0)
//   42           -> 14 partials (tmp_max_1)
//   14 + reg[126] -> 5 partials (tmp_max_2)
//   5 + reg[127]  -> 2 partials (tmp_max_3)
//   2            -> 1 result
// The fmax3 calls within each level are independent of each other, so the
// compiler can issue them back to back.
//
// Parameters:
//   reg - 128 words, each holding the IEEE-754 bit pattern of a float.
// Returns:
//   The maximum of the 128 reinterpreted values.
//
// NOTE(review): the last step passes tmp_max_3[1] twice rather than
// -__FLT_MAX__, which the templated row_max_reduction() uses for its
// 128-element case. Assuming fmax3 is a plain 3-way max, a row that is
// entirely -inf therefore yields -inf here, but -FLT_MAX from the templated
// version. Confirm which behaviour callers expect.
inline __device__  float row_max_reduction_128_elems(uint32_t reg[128]) {
    float tmp_max_0[42];
    float tmp_max_1[14];
    float tmp_max_2[5];
    float tmp_max_3[2];

    // Level 0: reg[0..125] -> 42 partial maxima.
    // __uint_as_float reinterprets the bits without the strict-aliasing UB
    // of reinterpret_cast<float &> on uint32_t storage.
    #pragma unroll
    for (int i = 0; i < 126; i+=3) {
        tmp_max_0[i/3] = __nv_ptx_builtin_ocg_fmax3(__uint_as_float(reg[i+0]),
                                                    __uint_as_float(reg[i+1]),
                                                    __uint_as_float(reg[i+2]));
    }
    // Level 1: 42 -> 14.
    #pragma unroll
    for (int i = 0; i < 42; i+=3) {
        tmp_max_1[i/3] = __nv_ptx_builtin_ocg_fmax3(tmp_max_0[i+0], tmp_max_0[i+1], tmp_max_0[i+2]);
    }
    // Level 2: the first 12 of the 14 partials -> 4.
    #pragma unroll
    for (int i = 0; i < 12; i+=3) {
        tmp_max_2[i/3] = __nv_ptx_builtin_ocg_fmax3(tmp_max_1[i+0], tmp_max_1[i+1], tmp_max_1[i+2]);
    }
    // The remaining two partials absorb the leftover element reg[126].
    tmp_max_2[4] = __nv_ptx_builtin_ocg_fmax3(tmp_max_1[12], tmp_max_1[13], __uint_as_float(reg[126]));
    // Level 3: 5 partials + the leftover element reg[127] -> 2.
    tmp_max_3[0] = __nv_ptx_builtin_ocg_fmax3(tmp_max_2[0], tmp_max_2[1], tmp_max_2[2]);
    tmp_max_3[1] = __nv_ptx_builtin_ocg_fmax3(tmp_max_2[3], tmp_max_2[4], __uint_as_float(reg[127]));

    // Final step: only two operands remain, so the third slot repeats one of them.
    return __nv_ptx_builtin_ocg_fmax3(tmp_max_3[0], tmp_max_3[1], tmp_max_3[1]);
}

// Returns the maximum element of `input`, computed with three-input max
// operations (__nv_ptx_builtin_ocg_fmax3).
//
// `input` must support `size(input)` as a compile-time constant (it is used in
// `if constexpr`) and element access via `input(i)`. It is presumably a
// register-resident CuTe tensor fragment (TODO confirm against callers).
//
// Sizes 32, 64 and 128 use hand-written reduction trees, whose independent
// fmax3 calls per level expose instruction-level parallelism. Any other size
// uses a generic linear fmax3 chain.
//
// In every path -__FLT_MAX__ takes part in the reduction, so (assuming fmax3
// is a plain 3-way max) the result is never below -FLT_MAX. An empty input
// returns -__FLT_MAX__.
template<class InputType>
inline __device__  float row_max_reduction(InputType const input) {
    if constexpr (size(input) == 32) {
        float tmp_max_0[10];
        float tmp_max_1[4];
        float tmp_max_2[1];

        // input(0..29) -> 10 partials.
        #pragma unroll
        for (int i = 0; i < 30; i+=3) {
            tmp_max_0[i/3] = __nv_ptx_builtin_ocg_fmax3(input(i+0), input(i+1), input(i+2));
        }
        // First 9 partials -> 3.
        #pragma unroll
        for (int i = 0; i < 9; i+=3) {
            tmp_max_1[i/3] = __nv_ptx_builtin_ocg_fmax3(tmp_max_0[i+0], tmp_max_0[i+1], tmp_max_0[i+2]);
        }
        // Leftover elements 30, 31 and the 10th partial.
        tmp_max_1[3] = __nv_ptx_builtin_ocg_fmax3(input(30), input(31), tmp_max_0[9]);

        tmp_max_2[0] = __nv_ptx_builtin_ocg_fmax3(tmp_max_1[0], tmp_max_1[1], tmp_max_1[2]);

        // -__FLT_MAX__ fills the unused third operand and clamps the result.
        return __nv_ptx_builtin_ocg_fmax3(tmp_max_2[0], tmp_max_1[3], -__FLT_MAX__);
    } else if constexpr (size(input) == 64) {
        float tmp_max_0[21];
        float tmp_max_1[7];
        float tmp_max_2[3];

        // input(0..62) -> 21 partials.
        #pragma unroll
        for (int i = 0; i < 63; i+=3) {
            tmp_max_0[i/3] = __nv_ptx_builtin_ocg_fmax3(input(i+0), input(i+1), input(i+2));
        }
        // 21 -> 7.
        #pragma unroll
        for (int i = 0; i < 21; i+=3) {
            tmp_max_1[i/3] = __nv_ptx_builtin_ocg_fmax3(tmp_max_0[i+0], tmp_max_0[i+1], tmp_max_0[i+2]);
        }
        // First 6 of 7 -> 2.
        #pragma unroll
        for (int i = 0; i < 6; i+=3) {
            tmp_max_2[i/3] = __nv_ptx_builtin_ocg_fmax3(tmp_max_1[i+0], tmp_max_1[i+1], tmp_max_1[i+2]);
        }
        // Last partial plus the leftover element 63; -__FLT_MAX__ fills the slot.
        tmp_max_2[2] = __nv_ptx_builtin_ocg_fmax3(tmp_max_1[6], input(63), -__FLT_MAX__);

        return __nv_ptx_builtin_ocg_fmax3(tmp_max_2[0], tmp_max_2[1], tmp_max_2[2]);
    } else if constexpr (size(input) == 128) {
        float tmp_max_0[42];
        float tmp_max_1[14];
        float tmp_max_2[5];
        float tmp_max_3[2];

        // input(0..125) -> 42 partials.
        #pragma unroll
        for (int i = 0; i < 126; i+=3) {
            tmp_max_0[i/3] = __nv_ptx_builtin_ocg_fmax3(input(i+0), input(i+1), input(i+2));
        }
        // 42 -> 14.
        #pragma unroll
        for (int i = 0; i < 42; i+=3) {
            tmp_max_1[i/3] = __nv_ptx_builtin_ocg_fmax3(tmp_max_0[i+0], tmp_max_0[i+1], tmp_max_0[i+2]);
        }
        // First 12 of 14 -> 4.
        #pragma unroll
        for (int i = 0; i < 12; i+=3) {
            tmp_max_2[i/3] = __nv_ptx_builtin_ocg_fmax3(tmp_max_1[i+0], tmp_max_1[i+1], tmp_max_1[i+2]);
        }
        // Leftover elements 126 and 127 are folded in at the narrow levels.
        tmp_max_2[4] = __nv_ptx_builtin_ocg_fmax3(tmp_max_1[12], tmp_max_1[13], input(126));
        tmp_max_3[0] = __nv_ptx_builtin_ocg_fmax3(tmp_max_2[0], tmp_max_2[1], tmp_max_2[2]);
        tmp_max_3[1] = __nv_ptx_builtin_ocg_fmax3(tmp_max_2[3], tmp_max_2[4], input(127));

        return __nv_ptx_builtin_ocg_fmax3(tmp_max_3[0], tmp_max_3[1], -__FLT_MAX__);
    } else {
        // Generic fallback for sizes without a hand-tuned tree. Previously this
        // case returned 0.f, which is not the maximum (e.g. for all-negative rows).
        // Seeding with -__FLT_MAX__ keeps the same clamping as the branches above.
        float acc = -__FLT_MAX__;
        int const n = static_cast<int>(size(input));
        #pragma unroll
        for (int i = 0; i + 1 < n; i += 2) {
            acc = __nv_ptx_builtin_ocg_fmax3(acc, input(i), input(i+1));
        }
        if (n % 2 != 0) {
            // Odd size: fold in the last element; -__FLT_MAX__ fills the slot.
            acc = __nv_ptx_builtin_ocg_fmax3(acc, input(n-1), -__FLT_MAX__);
        }
        return acc;
    }
}
