 3) |
        (SPARSE_ENABLED << 2) |
        (SPARSE_METADATA_ID2 << 0)) << 32;
}

// Packs the template arguments into a 64-bit descriptor word at compile time.
//
// Every field is first placed inside a 32-bit layout (lowest field at bit 2,
// highest at bit 29) and the combined value is then shifted left by 32, so the
// fields end up in the upper half of the result and the lower 32 bits are zero.
//
// UTCMMA_M is stored divided by 32 (>> 5) and UTCMMA_N divided by 8 (>> 3);
// all other arguments are stored as given. No range checking is performed, so
// an argument wider than its slot bleeds into the neighbouring field.
template<uint64_t UTCMMA_M, uint64_t UTCMMA_N,
         uint64_t B_TRANSPOSE, uint64_t A_TRANSPOSE,
         uint64_t B_NEGATE, uint64_t A_NEGATE,
         uint64_t SRC_B_TYPE, uint64_t SRC_A_TYPE,
         uint64_t SFA_ID, uint64_t SFB_ID,
         uint64_t SF_FORMAT,
         uint64_t SPARSE_METADATA_FORMAT,
         uint64_t SPARSE_ENABLED>
inline __device__ constexpr uint64_t build_block_scale_utcmma_instruction_desc() {
    // Kept as a single return expression; fields are listed from the least
    // significant position upwards (bit numbers are pre-shift, i.e. relative
    // to the upper 32-bit half of the final value).
    return static_cast<uint64_t>(
               (SPARSE_ENABLED << 2)             // from bit 2
             | (SFB_ID << 4)                     // from bit 4
             | (SPARSE_METADATA_FORMAT << 6)     // from bit 6
             | (SRC_A_TYPE << 7)                 // from bit 7
             | (SRC_B_TYPE << 10)                // from bit 10
             | (A_NEGATE << 13)                  // from bit 13
             | (B_NEGATE << 14)                  // from bit 14
             | (A_TRANSPOSE << 15)               // from bit 15
             | (B_TRANSPOSE << 16)               // from bit 16
             | ((UTCMMA_N >> 3) << 17)           // N / 8, from bit 17
             | (SF_FORMAT << 23)                 // from bit 23
             | ((UTCMMA_M >> 5) << 25)           // M / 32, from bit 25
             | (SFA_ID << 29)                    // from bit 29
           ) << 32;
}

// 64-bit matrix descriptor plus the bookkeeping needed to step it through a
// ring of equally sized buffers.
//
// Layout of `desc` as assembled by the constructor and set_smem():
//   bits 63-61  swizzle mode
//   bits 51-49  base offset
//   bits 48-46  descriptor version
//   bits 45-32  stride-dimension byte offset  / 16
//   bits 29-16  leading-dimension byte offset / 16
//   low bits    (smem & 0xFFFFFF) / 16, OR-ed in by set_smem()
//
// All buffer-stepping helpers operate on the low 32-bit word of `desc` only,
// accessed through an int2 view (`.x` is the low word).
class Smem_utcmma_descriptor {
public:
    // Builds the static part of the descriptor (everything except the address).
    //
    // SWIZZLE_MODE:          SWIZZLE_NONE = 0, SWIZZLE_128B = 2, SWIZZLE_64B = 4,
    //                        SWIZZLE_32B = 6, SWIZZLE_128B_ATOM32B = 1;
    //                        3, 5 and 7 are not assigned.
    // BASE_OFFSET:           placed at bits 51-49. The original note says it is
    //                        valid only for some matrix descriptors.
    // DESC_VERSION:          placed at bits 48-46. The original note says it
    //                        must be 1 but flags this as unconfirmed.
    // BYTES_PER_LEADING_DIM: byte offset, stored divided by 16 (low 4 bits dropped).
    // BYTES_PER_STRIDE_DIM:  byte offset, stored divided by 16 (low 4 bits dropped).
    //
    // `initial_desc` and `max_desc` are seeded from the freshly built
    // descriptor so that no member is left indeterminate; set_smem() must still
    // be called before the buffer-stepping helpers give meaningful results.
    inline __device__ Smem_utcmma_descriptor(uint64_t SWIZZLE_MODE,
                                             uint64_t BASE_OFFSET,
                                             uint64_t DESC_VERSION,
                                             uint64_t BYTES_PER_LEADING_DIM,
                                             uint64_t BYTES_PER_STRIDE_DIM) {
        const uint64_t SWIZZLE_MODE_IN_BIT_LOCATION = SWIZZLE_MODE << 61;  // bits 63-61
        const uint64_t BASE_OFFSET_IN_BIT_LOCATION  = BASE_OFFSET << 49;   // bits 51-49
        const uint64_t DESC_VERSION_IN_BIT_LOCATION = DESC_VERSION << 46;  // bits 48-46

        // Both byte offsets are 16-byte aligned quantities; the 4 LSBs are not stored.
        const uint64_t STRIDE_DIM_BYTE_OFFSET                 = BYTES_PER_STRIDE_DIM >> 4;
        const uint64_t STRIDE_DIM_BYTE_OFFSET_IN_BIT_LOCATION = STRIDE_DIM_BYTE_OFFSET << 32;  // bits 45-32

        const uint64_t LEADING_DIM_BYTE_OFFSET                 = BYTES_PER_LEADING_DIM >> 4;
        const uint64_t LEADING_DIM_BYTE_OFFSET_IN_BIT_LOCATION = LEADING_DIM_BYTE_OFFSET << 16;  // bits 29-16

        desc = (SWIZZLE_MODE_IN_BIT_LOCATION |
                BASE_OFFSET_IN_BIT_LOCATION |
                DESC_VERSION_IN_BIT_LOCATION |
                STRIDE_DIM_BYTE_OFFSET_IN_BIT_LOCATION |
                LEADING_DIM_BYTE_OFFSET_IN_BIT_LOCATION);

        // Give the bookkeeping members a defined value until set_smem() runs.
        initial_desc = desc;
        max_desc     = static_cast<uint32_t>(desc);
    }

    // ORs the address `smem` (masked to 24 bits, divided by 16) into the low
    // word of the descriptor, snapshots the result as the "initial" descriptor
    // and computes the low-word value that corresponds to the last of
    // BUFFER_COUNT buffers of BYTES_PER_BUFFER bytes each.
    //
    // Because the address is OR-ed in, this is intended to be called once per
    // descriptor; a second call merges the bits of both addresses.
    template<int BYTES_PER_BUFFER, int BUFFER_COUNT>
    inline __device__ void set_smem(uint32_t smem) {
        int2 &tmp = reinterpret_cast<int2 &>(desc);
        // NOTE(review): the mask keeps 24 address bits, i.e. a 20-bit field
        // after the shift, which would reach into the leading-dimension field
        // (starting at bit 16) for addresses >= 1 MiB. Presumably harmless for
        // real shared-memory addresses - confirm against the hardware layout.
        tmp.x |= (static_cast<uint64_t>(smem & 0xFFFFFF) >> 4);

        // Snapshot the whole 64-bit descriptor (not just the low word) so that
        // initial_desc is a complete, valid descriptor in its own right.
        reinterpret_cast<int2 &>(initial_desc) = tmp;

        max_desc = tmp.x + (BYTES_PER_BUFFER >> 4) * (BUFFER_COUNT - 1);
    }

    // Advances the descriptor to the next buffer in the ring: adds one buffer
    // (in 16-byte units) to the low word, or, once the low word has reached
    // max_desc, steps back by BUFFER_COUNT - 1 buffers to the first one.
    // The comparison against max_desc (uint32_t) is carried out unsigned.
    template<int BYTES_PER_BUFFER, int BUFFER_COUNT>
    inline __device__ void increment_smem_buffer() {
        int2 &tmp = reinterpret_cast<int2 &>(desc);
        tmp.x += (tmp.x >= max_desc) ? -(BYTES_PER_BUFFER >> 4) * (BUFFER_COUNT - 1)
                                    :  (BYTES_PER_BUFFER >> 4);
    }

    // Repositions the descriptor at buffer index `stage`, measured from the
    // snapshot taken in set_smem(). No bounds check is applied to `stage`.
    template<int BYTES_PER_BUFFER>
    inline __device__ void stage_increment_smem_buffer(int stage) {
        int2 &tmp_initial = reinterpret_cast<int2 &>(initial_desc);
        int2 &tmp = reinterpret_cast<int2 &>(desc);

        tmp.x = tmp_initial.x + (BYTES_PER_BUFFER >> 4) * stage;
    }

    // Adds a compile-time byte offset (divided by 16) to the low word.
    template<int BYTES_OFFSET>
    inline __device__ void add_smem_offset() {
        int2 &tmp = reinterpret_cast<int2 &>(desc);
        tmp.x += (BYTES_OFFSET >> 4);
    }

    // Descriptor as it was right after set_smem() (or after construction if
    // set_smem() has not been called yet).
    uint64_t initial_desc;
    // Current descriptor value.
    uint64_t desc;
    // Low-word value of the descriptor when it points at the last buffer.
    uint32_t max_desc;
};
