).join(arg_defs)
        func_export_decl = get_export_declaration()
        inline_attr = (
            "C10_ALWAYS_INLINE_ATTRIBUTE" if config.cpp.force_inline_kernel else ""
        )
        code.writeline(
            f'extern "C" {func_export_decl} void {inline_attr} {kernel_decl_name}({arg_defs})'
        )

        # 3. Function body
        with code.indent():
            if enable_kernel_profile:
                graph_id = V.graph.graph_id
                prefix = "graph_" + str(graph_id) + "_" if graph_id is not None else ""
                code.writelines(
                    [
                        (
                            "torch::aot_inductor::RAIIAtenRecordFunctionHandle "
                            f'record_{prefix + kernel_name}_("{prefix + kernel_name}", nullptr);'
                        )
                    ]
                )
            for old, new in self.args.aliases():
                code.writeline(f"auto {old} = {new};")
            code.splice(self.loops_code)
        return code.getvalue()

    def call_kernel(self, wrapper, kernel_name, debug_handle: Optional[int] = None):
        """Emit the host-side call to this kernel through the wrapper codegen.

        Only the call arguments and their C++ types are needed here; the
        argument *definitions* (first element of cpp_argdefs) are used when
        declaring the kernel, not when calling it.
        """
        argdefs = self.args.cpp_argdefs()
        call_args, arg_types = argdefs[1], argdefs[2]
        wrapper.generate_kernel_call(
            kernel_name,
            call_args,
            triton=False,
            arg_types=arg_types,
            debug_handle=debug_handle,
        )


class WorkSharing:
    """Manages an open `#pragma omp parallel` region while code is emitted.

    At most one region is open at a time; opening with a different thread
    count closes the current region first. Indentation of the region body is
    tracked with an ExitStack so it unwinds cleanly on close/exit.
    """

    def __init__(self, code):
        self.code = code
        self.in_parallel = False
        self.num_threads = None
        self.stack = contextlib.ExitStack()

    def parallel(self, threads):
        # A region open with a different thread count cannot be reused.
        if self.in_parallel and threads != self.num_threads:
            self.close()
        if self.in_parallel:
            return
        self.num_threads = threads
        self.in_parallel = True
        pragma = (
            "#pragma omp parallel"
            if config.cpp.dynamic_threads
            else f"#pragma omp parallel num_threads({threads})"
        )
        self.code.writeline(pragma)
        self.stack.enter_context(self.code.indent())
        self.code.writeline("int tid = omp_get_thread_num();")

    def single(self):
        """Emit `#pragma omp single` if inside a parallel region; return
        whether a region is currently open."""
        if self.in_parallel:
            self.code.writeline("#pragma omp single")
        return self.in_parallel

    def close(self):
        """Unwind the region's indentation and mark it closed."""
        self.stack.close()
        self.in_parallel = False

    def __enter__(self):
        self.stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stack.__exit__(exc_type, exc_val, exc_tb)


@dataclasses.dataclass
class LoopLevel:
    """One level of a generated C++ loop nest.

    Renders (via `lines`) to a C++ `for(...)` statement, optionally preceded
    by an OpenMP or GCC pragma, and supports loop-tiling via `tile`
    (see Note [tiled_size]).
    """

    # Loop induction variable; filled in when LoopNest.build creates the level.
    var: Optional[sympy.Expr] = None
    # Exclusive upper bound of the iteration range.
    size: Optional[sympy.Expr] = None
    # Inclusive lower bound of the iteration range.
    offset: sympy.Expr = sympy.S.Zero
    # Note [tiled_size]
    # We may do loop-tiling at this loop level.
    # When var is in [offset, tiled_size), we will perform the vectorization kernel.
    # When var is in [tiled_size, size), we will perform the scalar or masked vectorization kernel.
    # for (var = offset; var < size; var += steps) {
    #     if (var >= offset && var < tiled_size) vec_loop_body();
    #     if (var >= tiled_size && var < size) scalar_or_maskvec_loop_body();
    # }
    tiled_size: sympy.Expr = sympy.S.Zero
    # Loop increment; becomes the tiling factor after `tile`.
    steps: sympy.Expr = sympy.S.One
    # Number of loop levels covered by `#pragma omp for`:
    # 0 = not parallel, >1 adds a collapse(...) clause.
    parallel: int = 0
    # Emit an OpenMP `simd simdlen(...)` clause (when simd_nelements > 1).
    simd_omp: bool = False
    # This level runs the explicitly vectorized kernel; no pragma is emitted.
    simd_vec: bool = False
    # This level is folded into an enclosing collapse(...) pragma, so only
    # the `for` line itself is emitted.
    collapsed: bool = False
    is_reduction: bool = False

    def __post_init__(self):
        # Regarding the C++/OpenMP backend, `cpu_vec_isa.pick_vec_isa()` to check
        # vectorization ISA is a time-consuming and one-shot operation. It leads
        # to taking a longer time to import `codegen.cpp` package because the
        # `LoopLevel` of the package is decorated by `@dataclasses.dataclass` while
        # the decorator will invoke `cpu_vec_isa.pick_vec_isa()` to initialize the
        # `simd_nelements` of the `LoopLevel`. It might introduce additional compilation
        # overhead to the Triton backend. Therefore, we moved the `simd_nelements` to
        # `__post_init__`
        picked_vec_isa: cpu_vec_isa.VecISA = cpu_vec_isa.pick_vec_isa()
        self.simd_nelements: int = picked_vec_isa.nelements() if picked_vec_isa else 0

    def tile(self, factor):
        """Return a new LoopLevel over the same range stepping by `factor`.

        The new level keeps this level's parallelism and reduction flags,
        marks itself as vectorized (`simd_vec`), and records `tiled_size` as
        the largest multiple of `factor` within `size`
        (see Note [tiled_size]).
        """
        sympy_factor = sympy.Integer(factor)
        loop = LoopLevel(self.var, self.size)
        loop.steps = sympy_factor
        loop.simd_vec = True
        loop.tiled_size = FloorDiv(loop.size, sympy_factor) * sympy_factor
        loop.parallel = self.parallel
        loop.collapsed = False
        loop.is_reduction = self.is_reduction
        return loop

    def lines(self):
        """Render this level as a list of C++ source lines.

        Returns None when the loop is provably empty (offset == size) and
        redundant-loop elimination is enabled; otherwise returns
        [pragma, for-line] or just [for-line] when no pragma applies or the
        level is collapsed into an outer pragma.
        """
        offset_expr = cexpr_index(self.offset)
        size_expr = cexpr_index(self.size)
        if config.cpp.no_redundant_loops and offset_expr == size_expr:
            return None
        # Trailing space is intentional: `simd` is spliced before the next token.
        simd = (
            f"simd simdlen({self.simd_nelements}) "
            if self.simd_omp and self.simd_nelements > 1
            else ""
        )
        if self.parallel:
            # TODO(jansel): look into chunk size and other schedules
            line1 = "#pragma omp for"
            if self.parallel > 1:
                line1 += f" collapse({self.parallel})"
            if self.simd_omp:
                # NOTE(review): when parallel == 1 there is no collapse suffix,
                # so line1 ends in "for" with no trailing space and this replace
                # is a no-op — the simd clause is silently dropped. Confirm
                # whether that is intended.
                line1 = line1.replace(" for ", f" for {simd}")
        elif self.simd_vec:
            # Explicitly vectorized body: no pragma needed.
            line1 = ""
        elif self.simd_omp:
            line1 = f"#pragma omp {simd}"
        elif not self.is_reduction and cpp_builder.is_gcc():
            # Hint GCC that iterations are independent.
            line1 = "#pragma GCC ivdep"
        else:
            line1 = ""
        offset_str = f"{INDEX_TYPE} {self.var}={offset_expr}"
        size_str = f"{self.var}<{size_expr}"
        if self.steps.is_number:
            steps_str = f"{self.var}+={cexpr_index(self.steps)}"
        else:
            # If the step size is 0, change it to 1 because a step size of 0
            # will cause floating point exception (core dump) during parallelization.
            steps_str = (
                f"{self.var}+=({cexpr_index(self.steps)} == 0 ? "
                f"1 : {cexpr_index(self.steps)})"
            )
        line2 = f"for({offset_str}; {size_str}; {steps_str})"
        if self.collapsed or not line1:
            return [line2]
        return [line1, line2]


@dataclasses.dataclass
class LoopNest:
    """
    A loop-nest-like structure. It is built with the `build` method
    as a loop nest and then will perform loop-tiling at some depth.

    A typical case is for vectorization, where we typically do loop-tiling
    at the innermost loop level. A more complicated case is when we do
    2D tiling at both the innermost and outer levels.
    """

    # Outermost-to-innermost loop levels; None means an empty (scalar) nest.
    loops: Optional[list[LoopLevel]] = None
    # The kernel executed at the innermost level.
    kernel: Optional[CppKernel] = None

    @staticmethod
    def build(kernel: CppKernel):
        """Build a LoopNest with the given `kernel` as the leaf"""
        itervars = kernel.itervars
        ranges = kernel.ranges
        reduction_depth = kernel.reduction_depth
        assert reduction_depth is not None

        loops: Optional[list[LoopLevel]] = None
        for loop_idx, (var, size) in enumerate(zip(itervars, ranges)):
            loop = LoopLevel(var, size)
            if not loops:
                loops = [loop]
            else:
                loops.append(loop)
            # Levels at/after reduction_depth carry the kernel's reduction flag.
            if loop_idx >= reduction_depth:
                loop.is_reduction = kernel.is_reduction

        loop_nest = LoopNest(loops)
        return loop_nest

    def __bool__(self):
        # Truthy only when the nest actually has loops.
        return bool(self.loops)

    @cache_on_self
    def max_parallel_depth(self):
        """
        Maximal allowed depth for parallelism: All reduction or non-reduction levels.
        When the range of the first inner loop beyond the maximum parallel depth is much
        larger than the range of all outer loops within the maximum parallel depth,
        change the starting depth of parallelism to the first inner loop and recalculate
        the maximum parallel depth.
        """
        if self.loops is None:
            return ParallelDepth(parallel_depth=0, start_depth=0)

        start_depth = 0
        max_depth = 0
        # Count the leading run of levels that share the first level's
        # reduction-ness; accumulate their combined trip count in num_steps.
        is_reduction = self.loops[0].is_reduction
        num_steps = sympy.Integer(1)
        for loop in self.loops:
            if loop.is_reduction != is_reduction:
                break
            num_steps = num_steps * FloorDiv(loop.size, loop.steps)
            max_depth += 1

        def get_simd_vec_depth(loops):
            # Return the first loop level which is simd_vec
            for i, loop in enumerate(loops):
                if loop.simd_vec:
                    return i
            return None

        simd_vec_depth = get_simd_vec_depth(self.loops)

        def has_scalar_kernel(loop_nest: LoopNest):
            # True when any leaf kernel is a scalar (non-vectorized) kernel.
            assert isinstance(loop_nest.kernel, CppKernelProxy)
            return any(
                not isinstance(kernel, CppVecKernel)
                for kernel in loop_nest.kernel.kernels
            )

        # When the number of steps of the first inner loop is much larger than the number of steps of
        # all outer loops, change `start_depth` to the first inner loop and recalculate `max_depth`.
        # The 300x factor is a heuristic threshold; both counts must be
        # concrete integers for the comparison to be meaningful.
        if (
            max_depth < len(self.loops)
            and isinstance(num_steps, sympy.Integer)
            and isinstance(self.loops[max_depth].size, sympy.Integer)
            and num_steps * 300
            < FloorDiv(self.loops[max_depth].size, self.loops[max_depth].steps)
            and not (
                # Disable parallel reduction under the vec loop
                simd_vec_depth is not None
                and max_depth > simd_vec_depth
                and self.loops[max_depth].is_reduction
                and has_scalar_kernel(self)
            )
        ):
            start_depth = max_depth
            max_depth = 0
            is_reduction = self.loops[start_depth].is_reduction
            for i in range(start_depth, len(self.loops)):
                if self.loops[i].is_reduction != is_reduction:
                    break
                max_depth += 1
        return ParallelDepth(parallel_depth=max_depth, start_depth=start_depth)

    def mark_parallel(self, par_depth):
        """Mark the level at `par_depth.start_depth` as the parallel loop and
        collapse the levels beneath it into its pragma."""
        assert par_depth.parallel_depth <= self.max_parallel_depth().parallel_depth, (
            "Parallel depth cannot exceed the maximal allowed parallel depth"
        )
        assert self.loops is not None
        assert len(self.loops) >= par_depth.parallel_depth
        loop = self.loops[par_depth.start_depth]
        loop.parallel = par_depth.parallel_depth
        if loop.is_reduction:
            metrics.parallel_reduction_count += 1
        # NOTE(review): the upper bound is `parallel_depth` (a count), not
        # `start_depth + parallel_depth`; when start_depth > 0 this range can
        # be empty even though parallel_depth > 1 — confirm this is intended.
        for i in range(par_depth.start_depth + 1, par_depth.parallel_depth):
            self.loops[i].collapsed = True

    def tile(self, depth, factor):
        """
        Do loop-tiling at the `depth` level with `factor`.
            for (x0 = 0; x0 < x0_end; x0++)
            ->
            for (x0 = 0; x0 < x0_end; x0 += factor)
        See details in Note [tiled_size].
        """
        assert self.loops
        self.loops[depth] = self.loops[depth].tile(factor)
        return self.loops[depth]

    def get_kernel(self) -> CppKernel:
        """Return the leaf kernel; must have been set via `set_kernel`."""
        assert self.kernel
        return self.kernel

    def set_kernel(self, kernel):
        # Attach the leaf kernel executed inside the innermost loop.
        self.kernel = kernel

    def from_loop_level(self, level: int):
        """Return a new LoopNest holding the levels from `level` inward
        (loops=None when `level` equals the nest depth), sharing this kernel."""
        assert self.loops
        assert len(self.loops) >= level
        loops = None if level == len(self.loops) else self.loops[level:]
        return LoopNest(loops, self.kernel)