ER_VECTOR_{guid}    = BYTES_PER_VECTOR_{guid} / BYTES_PER_ACCESS_{guid};
// Number of swizzle units that fit in one row: bytes per row divided by the byte size of a
// single swizzle unit (banks-per-swizzle-unit times bytes per bank). It is the modulus applied
// to the XOR term when swizzled addresses are computed.
static constexpr int SWIZZLE_SCALE_{guid}        = BYTES_PER_ROW_{guid} / ({banks_per_swizzle_unit} * BYTES_PER_BANK_{p_id});

// Per-thread helper that writes register-held vectors to a swizzled byte layout addressed
// through a 32-bit base address.
// NOTE(review): the smem/sts naming suggests the destination is shared memory - confirm.
class Sts_tile_{guid} {{
public:
    // Records the base address and precomputes this thread's unswizzled byte offset.
    //   smem: 32-bit base address of the destination buffers.
    //   tiw:  index that selects the row inside a row group and the position within the row
    //         (presumably the thread index within its warp - confirm against the caller).
    //   wid:  index that selects the row group after reduction modulo 4
    //         (presumably the warp index - confirm against the caller).
    inline __device__ Sts_tile_{guid}(uint32_t smem, int tiw, int wid) : smem(smem) {{
        // How many vectors span one maximal TMA row, and where this thread sits within its row.
        const auto vectors_per_tma_row = MAX_TMA_LOAD_ROW_BYTES_{p_id} / BYTES_PER_VECTOR_{guid};
        const auto thread_in_row       = tiw % THREADS_PER_ROW_{A_or_B}_{m_id};

        const uint32_t row_id       = (wid % 4) * ROWS_PER_WARP_{A_or_B}_{m_id} + tiw / THREADS_PER_ROW_{A_or_B}_{m_id};
        const uint32_t tma_block_id = thread_in_row / vectors_per_tma_row;
        const uint32_t col_id       = thread_in_row % vectors_per_tma_row;

        // Each TMA block contributes ROWS_PER_TILE rows; rows are BYTES_PER_ROW bytes apart.
        const auto linear_row = tma_block_id * ROWS_PER_TILE_{guid} + row_id;
        smem_offset = linear_row * BYTES_PER_ROW_{guid} + col_id * BYTES_PER_VECTOR_{guid};
    }}

    // Stores every vector in reg to the buffer selected by buffer_id, one access-sized piece
    // at a time, applying an XOR swizzle to the swizzle-unit index of each destination address.
    //   buffer_id: index of the destination buffer; buffers are BYTES_PER_SMEM bytes apart.
    //   reg:       source registers indexed as [load within tile][load within row][register].
    inline __device__ void store(int buffer_id, r32 reg[LDS_PER_TILE_{A_or_B}_{m_id}][LDS_PER_ROW_{A_or_B}_{m_id}][REGISTERS_PER_VECTOR_{guid}]) {{
        const uint32_t buffer_base = smem + buffer_id * BYTES_PER_SMEM_{guid};

        // Byte size of one swizzle unit, and the divisor that maps a unit index to its XOR term.
        const auto swizzle_unit_bytes = {banks_per_swizzle_unit} * BYTES_PER_BANK_{p_id};
        const auto units_per_xor_step = 8 / {banks_per_swizzle_unit};

        #pragma unroll
        for (int i = 0; i < LDS_PER_TILE_{A_or_B}_{m_id}; ++i) {{
            // Unswizzled offset of the i-th group of rows handled by this thread.
            const uint32_t row_base = smem_offset + i * ROWS_PER_LOAD_{A_or_B}_{m_id} * BYTES_PER_ROW_{guid};
            #pragma unroll
            for (int j = 0; j < LDS_PER_ROW_{A_or_B}_{m_id}; ++j) {{
                // Unswizzled offset of the j-th vector, plus a byte view of its source registers.
                const uint32_t vector_base  = row_base + j * ROWS_PER_TILE_{guid} * THREADS_PER_ROW_{A_or_B}_{m_id} * BYTES_PER_VECTOR_{guid};
                char* const    vector_bytes = reinterpret_cast<char*>(reg[i][j]);
                #pragma unroll
                for (int k = 0; k < ACCESS_PER_VECTOR_{guid}; ++k) {{
                    const uint32_t linear_addr = vector_base + k * BYTES_PER_ACCESS_{guid};

                    // Split the address into a swizzle-unit index and a byte offset inside the unit;
                    // only the unit index is permuted, the inner offset is carried over unchanged.
                    const uint32_t linear_unit   = linear_addr / swizzle_unit_bytes;
                    const auto     xor_term      = (linear_unit / units_per_xor_step) % SWIZZLE_SCALE_{guid};
                    const uint32_t swizzled_unit = linear_unit ^ xor_term;
                    const auto     byte_in_unit  = linear_addr % swizzle_unit_bytes;

                    const auto dst_addr = buffer_base + swizzled_unit * swizzle_unit_bytes + byte_in_unit;
                    sts_{bits_per_access}(dst_addr, reinterpret_cast<r32*>(vector_bytes + k * BYTES_PER_ACCESS_{guid}));
                }}
            }}
        }}
    }}

private:
    uint32_t smem;         // base address passed to the constructor
    uint32_t smem_offset;  // this thread's unswizzled byte offset from the buffer base
}};