# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "2026.1.3"

# Explicit public API of this module; names not listed here are internal.
# Some entries are defined later in this file, others are re-exported from
# sibling modules by the package importing this file.
__all__ = [
    "SUPPORTS_BFLOAT16",
    "is_bfloat16_supported",
    "is_vLLM_available",
    "prepare_model_for_kbit_training",
    "xformers",
    "xformers_attention",
    "xformers_version",
    "__version__",
    "importlib_version",
    "HAS_FLASH_ATTENTION",
    "HAS_FLASH_ATTENTION_SOFTCAPPING",
    "USE_MODELSCOPE",
    "platform_system",
    "patch_tokenizer",
    "get_statistics",
    "Unsloth_Offloaded_Gradient_Checkpointer",
    "offload_to_disk",
    "offload_input_embeddings",
    "offload_output_embeddings",
    "unsloth_offloaded_gradient_checkpoint",
    "torch_compile_options",
    "patch_linear_scaling",
    "patch_llama_rope_scaling",
    "create_boolean_mask",
    "torch_amp_custom_fwd",
    "torch_amp_custom_bwd",
    # "accelerate_old_send_to_device",
    # "accelerate_new_send_to_device",
    "patch_gradient_accumulation_fix",
    "patch_compiling_bitsandbytes",
    "patch_regional_compilation",
    "patch_layernorm",
    "patch_torch_compile",
    "patch_model_and_tokenizer",
    "patch_unsloth_gradient_checkpointing",
    "unpatch_unsloth_gradient_checkpointing",
    "patch_gradient_checkpointing",
    "unpatch_gradient_checkpointing",
    "HAS_CUT_CROSS_ENTROPY",
    "EMPTY_LOGITS",
    "fused_linear_cross_entropy",
    "unsloth_fused_ce_loss",
    "patch_unsloth_smart_gradient_checkpointing",
    "unpatch_unsloth_smart_gradient_checkpointing",
    "patch_compiled_autograd",
    "process_vision_info",
    "unsloth_compile_transformers",
    "patch_fast_lora",
    "validate_loftq_config",
    "RaiseUninitialized",
    "fast_inference_setup",
    "patch_peft_fast_inference",
    "error_out_no_vllm",
    "dequantize_module_weight",
    "patch_hf_quantizer",
    "verify_fp8_support_if_applicable",
    "_get_inference_mode_context_manager",
    "hf_login",
    "make_fast_generate_wrapper",
]

import torch
from typing import Union, Optional, List, Any, Callable, Tuple, Iterator
from platform import system as platform_system

platform_system = platform_system()
import numpy as np
import contextlib
import re
from dataclasses import dataclass, field
import functools
import textwrap
import logging
import warnings, subprocess, inspect, psutil, os, math
from unsloth_zoo.utils import Version, get_quant_type
from importlib.metadata import version as importlib_version
from ..device_type import (
    is_hip,
    get_device_type,
    DEVICE_TYPE,
    DEVICE_TYPE_TORCH,
    DEVICE_COUNT,
    ALLOW_PREQUANTIZED_MODELS,
)
from unsloth_zoo.log import logger
from unsloth_zoo.tokenizer_utils import (
    patch_tokenizer as _patch_tokenizer,
)
from unsloth_zoo.rl_environments import (
    check_python_modules,
    create_locked_down_function,
    execute_with_time_limit,
    Benchmarker,
)
from unsloth_zoo.patching_utils import (
    patch_compiling_bitsandbytes,
    patch_layernorm,
    patch_torch_compile,
    patch_model_and_tokenizer,
    patch_compiled_autograd,
)
from unsloth_zoo.gradient_checkpointing import (
    Unsloth_Offloaded_Gradient_Checkpointer,
    unsloth_offloaded_gradient_checkpoint,
    patch_unsloth_gradient_checkpointing,
    unpatch_unsloth_gradient_checkpointing,
    Unsloth_Gradient_Checkpointer,
    unsloth_gradient_checkpoint,
    patch_gradient_checkpointing,
    unpatch_gradient_checkpointing,
    patch_unsloth_smart_gradient_checkpointing,
    unpatch_unsloth_smart_gradient_checkpointing,
)
from unsloth_zoo.loss_utils import (
    HAS_CUT_CROSS_ENTROPY,
    fused_linear_cross_entropy,
    _unsloth_get_batch_samples,
    unsloth_fused_ce_loss,
)
from unsloth_zoo.vision_utils import (
    process_vision_info,
)
from unsloth_zoo.compiler import (
    get_transformers_model_type,
    unsloth_compile_transformers as _unsloth_compile_transformers,
)
from unsloth_zoo.training_utils import (
    prepare_model_for_training,
)
from unsloth_zoo.temporary_patches import (
    TEMPORARY_PATCHES,
)

# Apply the hot fixes shipped with unsloth_zoo at import time.
for temporary_patch in TEMPORARY_PATCHES:
    temporary_patch()

# =============================================
# Disable some warnings which can get annoying
# These are broad per-module suppressions; enable them again by removing the
# matching filterwarnings call if you need to debug one of these libraries.
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "torch")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "torch")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "huggingface_hub")
warnings.filterwarnings(
    action = "ignore", category = FutureWarning, module = "huggingface_hub"
)
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "trl")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "trl")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "xformers")
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "subprocess")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "transformers")
warnings.filterwarnings(action = "ignore", category = FutureWarning, module = "accelerate")
warnings.filterwarnings(
    action = "ignore", category = RuntimeWarning, module = "multiprocessing"
)
warnings.filterwarnings(action = "ignore", category = RuntimeWarning, module = "multiprocess")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "triton")
warnings.filterwarnings(action = "ignore", category = UserWarning, module = "bitsandbytes")

# Stop "Special tokens have been added in the vocabulary, ..."
# CRITICAL + 1 silences every record from that tokenizer logger.
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.CRITICAL + 1)

# Shared error message used when torchao is required but not installed.
TORCHAO_MSG = "Error: torchao not found, please install with `pip install torchao`"


# Ignore logging messages
class HideLoggingMessage(logging.Filter):
    """Logging filter that drops any record whose message contains ``text``.

    Attach via ``logger.addFilter(HideLoggingMessage("substring"))``.
    """

    __slots__ = ("text",)

    def __init__(self, text):
        # Initialize logging.Filter so base attributes (name / nlen) exist;
        # the original skipped this, leaving the filter partially constructed.
        super().__init__()
        # Substring to look for in each record's formatted message.
        self.text = text

    def filter(self, x):
        # Keep the record only when the unwanted text is absent.
        return self.text not in x.getMessage()


# Stop vLLM messages
# Each vLLM logger below only exists for some vLLM versions / layouts, hence
# the individual try/except blocks: a missing module simply means there is
# nothing to silence. Set UNSLOTH_ENABLE_LOGGING=1 to keep all messages.
if os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") != "1":
    try:
        from vllm.worker.worker import logger as vllm_worker_logger

        vllm_worker_logger.addFilter(HideLoggingMessage("Sleep mode freed"))
        del vllm_worker_logger
    except:
        pass
    try:
        from vllm.v1.worker.gpu_worker import logger as vllm_gpu_worker_logger

        vllm_gpu_worker_logger.addFilter(HideLoggingMessage("Sleep mode freed"))
        del vllm_gpu_worker_logger
    except:
        pass
    try:
        from vllm.executor.executor_base import logger as vllm_executor_logger

        # Sleep / wake-up chatter from the executor.
        vllm_executor_logger.addFilter(HideLoggingMessage("to fall asleep"))
        vllm_executor_logger.addFilter(HideLoggingMessage("to wake up"))
        vllm_executor_logger.addFilter(HideLoggingMessage("Executor is not sleeping"))
        del vllm_executor_logger
    except:
        pass
    try:
        from vllm.core.block.prefix_caching_block import (
            logger as vllm_prefix_caching_logger,
        )

        vllm_prefix_caching_logger.addFilter(HideLoggingMessage("reset prefix cache"))
        del vllm_prefix_caching_logger
    except:
        pass
    try:
        from vllm.v1.core.block_pool import logger as vllm_block_pool_logger

        vllm_block_pool_logger.addFilter(HideLoggingMessage("reset prefix cache"))
        del vllm_block_pool_logger
    except:
        pass
    try:
        from vllm.lora.models import logger as vllm_lora_model_logger

        vllm_lora_model_logger.addFilter(
            HideLoggingMessage(
                "Regarding multimodal models, vLLM currently only supports adding"
            )
        )
        del vllm_lora_model_logger
    except:
        pass
    try:
        from vllm.attention.utils.fa_utils import (
            logger as vllm_attention_utils_fa_utils_logger,
        )

        vllm_attention_utils_fa_utils_logger.addFilter(
            HideLoggingMessage("Cannot use FA version")
        )
        del vllm_attention_utils_fa_utils_logger
    except:
        pass

# Silence known-noisy transformers / accelerate / huggingface_hub messages.
# The speedups for torchdynamo mostly come with GPU Ampere or higher and which is not detected here.
from transformers.training_args import logger as transformers_training_args_logger

transformers_training_args_logger.addFilter(HideLoggingMessage("The speedups"))
# torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED.
transformers_training_args_logger.addFilter(HideLoggingMessage("torch.distributed"))
# average_tokens_across_devices is set to True but it is invalid when world size is 1
transformers_training_args_logger.addFilter(
    HideLoggingMessage("average_tokens_across_devices")
)
del transformers_training_args_logger

# No label_names provided for model class
from transformers.trainer import logger as transformers_trainer_logger

transformers_trainer_logger.addFilter(HideLoggingMessage("No label_names"))

# The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config.
transformers_trainer_logger.addFilter(HideLoggingMessage("The tokenizer has new"))
del transformers_trainer_logger

# Using the default loss: `ForCausalLMLoss`.
try:
    from transformers.modeling_utils import logger as transformers_modeling_utils_logger

    transformers_modeling_utils_logger.addFilter(HideLoggingMessage("ForCausalLMLoss"))
    del transformers_modeling_utils_logger
except:
    pass

# The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
try:
    from accelerate.utils.modeling import logger as accelerate_utils_modeling_logger

    accelerate_utils_modeling_logger.addFilter(
        HideLoggingMessage("The model weights are not tied")
    )
    del accelerate_utils_modeling_logger
except:
    pass

# Setting `pad_token_id` to `eos_token_id`
try:
    from transformers.generation.utils import (
        logger as transformers_generation_utils_logger,
    )

    transformers_generation_utils_logger.addFilter(
        HideLoggingMessage("Setting `pad_token_id` to `eos_token_id`")
    )
    # "You have set `compile_config`
    transformers_generation_utils_logger.addFilter(HideLoggingMessage("compile_config"))
    del transformers_generation_utils_logger
except:
    pass

# The following generation flags are not valid and may be ignored:
try:
    from transformers.generation.configuration_utils import (
        logger as configuration_logger,
    )

    configuration_logger.addFilter(HideLoggingMessage("following generation flags"))
    del configuration_logger
except:
    pass

# Gemma3 It is strongly recommended to train Gemma3 models with the `eager`
try:
    from transformers.models.gemma3.modeling_gemma3 import logger as gemma3_logger

    gemma3_logger.addFilter(HideLoggingMessage("strongly recommended"))
    del gemma3_logger
except:
    pass

# Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed.
try:
    from huggingface_hub.file_download import logger as hub_logger

    hub_logger.addFilter(HideLoggingMessage("hf_xet"))
    del hub_logger
except:
    pass

# MXFP4 quantization requires triton >= 3.4.0
try:
    from transformers.quantizers.quantizer_mxfp4 import logger as mxfp4_logger

    mxfp4_logger.addFilter(HideLoggingMessage("requires triton"))
    del mxfp4_logger
except:
    pass

# You passed `quantization_config` or equivalent parameters
# append=True keeps these filters at the end so user-added filters win.
try:
    warnings.filterwarnings(
        action = "ignore",
        message = r".*quantization_config.*",
        category = UserWarning,
        append = True,
    )
except:
    pass

# UserWarning: Logical operators 'and' and 'or' are deprecated for non-scalar tensors; please use '&' or '|' instead
# Will be fixed in torch 2.8.1 https://github.com/pytorch/pytorch/issues/158463
try:
    warnings.filterwarnings(
        action = "ignore",
        message = r".*Logical operators 'and' and 'or'.*",
        category = UserWarning,
        append = True,
    )
except:
    pass

# Using a slow image processor as `use_fast`
try:
    from transformers.processing_utils import logger as processing_utils_logger

    processing_utils_logger.addFilter(HideLoggingMessage("`use_fast`"))
    del processing_utils_logger
except:
    pass

# Using a slow image processor as `use_fast`
# Same message is also emitted by the auto image-processor module.
try:
    from transformers.models.auto.image_processing_auto import (
        logger as processing_utils_logger,
    )

    processing_utils_logger.addFilter(HideLoggingMessage("`use_fast`"))
    del processing_utils_logger
except:
    pass

# `use_cache=True` is incompatible with gradient checkpointing
try:
    from transformers.trainer import logger as trainer_logger

    trainer_logger.addFilter(HideLoggingMessage("`use_cache=True`"))
    del trainer_logger
except:
    pass

# `use_cache=True` is incompatible with gradient checkpointing
try:
    from transformers.utils.generic import logger as trainer_logger

    trainer_logger.addFilter(HideLoggingMessage("`use_cache=True`"))
    del trainer_logger
except:
    pass

# We detected that you are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')
try:
    from transformers.modeling_utils import logger as modeling_utils_logger

    modeling_utils_logger.addFilter(HideLoggingMessage("anti-pattern"))
    del modeling_utils_logger
except:
    pass

# Errors out on
# Some weights of Gemma3nForConditionalGeneration were not initialized from the model checkpoint
from transformers.modeling_utils import logger as transformers_logger


class _RaiseUninitialized(logging.Handler):
    def __init__(self):
        super().__init__()

    def emit(self, record):
        record_lower = str(record).lower()
        if (
            ("some weights of" in record_lower)
            and ("score.weight" not in record_lower)
            and ("classifier.weight" not in record_lower)
            and ("cls.predictions" not in record_lower)
            and ("predictions.decoder" not in record_lower)
            and (os.environ.get("UNSLOTH_WARN_UNINITIALIZED", "1") == "1")
        ):
            raise Exception(
                f"Unsloth: Critical error since some weights are not initialized.\n"
                f"Please try updating Unsloth, transformers and timm via:\n"
                f"`pip install --upgrade --force-reinstall --no-cache-dir --no-deps unsloth unsloth_zoo transformers timm`\n"
                f"{str(record)}"
            )


class RaiseUninitialized:
    """Installs a `_RaiseUninitialized` handler on the transformers modeling
    logger; call `remove()` to detach it again."""

    def __init__(self):
        handler = _RaiseUninitialized()
        self.error_handler = handler
        transformers_logger.addHandler(handler)

    def remove(self):
        # Detach the handler installed by __init__.
        transformers_logger.removeHandler(self.error_handler)


# Patch get_model_param_count to record correct 4bit / 8bit
from transformers.trainer_pt_utils import is_deepspeed_zero3_enabled


def extract_quant_model_param_count(model) -> int:
    """
    Calculate quant model param count based on difference in param class. Returns int for param count.

    bitsandbytes `Params4bit` tensors pack two 4-bit values per stored element,
    so they are counted twice to approximate the unquantized parameter count.
    """
    count: int = 0
    # Iterate parameters directly - names are not needed for counting
    # (the original used named_parameters() and ignored the name).
    for p in model.parameters():
        if p.__class__.__name__ == "Params4bit":
            count += 2 * p.numel()
        else:
            count += p.numel()
    return count


def get_model_param_count(model, trainable_only = False):
    """
    Calculate model's total param count. If trainable_only is True then count only those requiring grads
    """
    if is_deepspeed_zero3_enabled():
        # Under ZeRO-3 parameters are sharded; `ds_numel` carries the full size.
        def numel(p):
            return getattr(p, "ds_numel", p.numel())
    else:
        def numel(p):
            return p.numel()

    total = 0
    for p in model.parameters():
        if trainable_only and not p.requires_grad:
            continue
        total += numel(p)

    # For quantized models the raw numel undercounts (4-bit packing), so
    # prefer the quantization-aware approximation when counting everything.
    is_quantized = (
        (not trainable_only)
        and hasattr(model, "config")
        and hasattr(model.config, "quantization_config")
    )
    if is_quantized:
        approx = extract_quant_model_param_count(model)
        if approx is not None:
            total = approx
    return total


import transformers.trainer_pt_utils

# Route both transformers call sites through the quantization-aware counter.
transformers.trainer_pt_utils.get_model_param_count = get_model_param_count
import transformers.trainer

transformers.trainer.get_model_param_count = get_model_param_count
# =============================================

# =============================================
# Edits all Config files to enable RoPE Scaling for all models


# Transformers had to update for Mistral Nemo 12b since Attention is (5120, 4096) now.
def patch_mistral_nemo_config(config):
    """Rewrite a MistralConfig's source code to accept a `head_dim` argument.

    Needed for Mistral Nemo 12b where attention is (5120, 4096), so head_dim
    can no longer be derived purely from hidden_size // num_attention_heads.
    Takes and returns the config class source as a string; no-op when the
    source already documents `head_dim (`.
    """
    # Already patched upstream - nothing to do.
    if "head_dim (" in config:
        return config

    replacements = (
        # 1) Document the new argument in the class docstring.
        (
            "If it is not specified, will default to `8`.",
            "If it is not specified, will default to `8`.\n"
            "        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):\n"
            "            The attention head dimension.",
        ),
        # 2) Add the keyword argument to __init__'s signature.
        (
            "num_key_value_heads=8,",
            "num_key_value_heads=8,\n        head_dim=None,",
        ),
        # 3) Store it on the instance, falling back to the derived value.
        (
            "self.sliding_window = sliding_window",
            "self.sliding_window = sliding_window\n        self.head_dim = head_dim or hidden_size // num_attention_heads\n",
        ),
    )
    for old, new in replacements:
        config = config.replace(old, new)
    return config


try:
    # Some Config files use layer_type_validation
    # for eg Gemma-2, so we must import it to stop errors.
    from transformers.configuration_utils import layer_type_validation
except:
    pass
from transformers import __version__ as transformers_version

# Newer transformers spells it PreTrainedConfig; older versions only have
# PretrainedConfig - import whichever exists.
try:
    from transformers import PreTrainedConfig
except:
    from transformers import PretrainedConfig

# Architectures whose Config classes get the rope_scaling patch below.
model_architectures = [
    "llama",
    "mistral",
    "gemma",
    "gemma2",
    "qwen2",
    "granite",
    "qwen3",
    "qwen3_moe",
    "falcon_h1",
]

for model_name in model_architectures:
    config_filepath = f"transformers.models.{model_name}.configuration_{model_name}"
    model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
    config_filename = f"{model_name.title().replace('_','')}Config"  # qwen3 arch folder is qwen3_moe but config is Qwen3Config. Need to remove underscore(_) for now
    try:
        exec(f"from {config_filepath} import {config_filename}", globals())
    except:
        continue

    # Grab the config class source so it can be rewritten and re-executed.
    try:
        config = inspect.getsource(eval(config_filename))
    except:
        continue
    # Some configs reference RopeParameters; import it so the re-exec works.
    if "RopeParameters" in config:
        try:
            exec(f"from {config_filepath} import RopeParameters", globals())
        except:
            continue

    # Skip architectures whose config already supports rope_scaling.
    if "rope_scaling" in config:
        continue
    # Inject a `rope_scaling=None` keyword into __init__ and store it on self.
    config = re.sub(
        r"(\*\*kwargs)[\s]{0,}\,[\s]{0,}\)[\s]{0,}\:",
        r"rope_scaling=None,"
        r"\n        **kwargs):\n"
        r"\n        self.rope_scaling = rope_scaling\n",
        config,
    )

    # Just for Mistral Nemo
    if model_name == "mistral":
        if Version(transformers_version) <= Version("4.42.4"):
            config = patch_mistral_nemo_config(config)

    # Re-execute the patched source and swap the class back into transformers.
    exec(config, globals())
    exec(f"import {config_filepath}", globals())
    exec(f"{config_filepath}.{config_filename} = {config_filename}", globals())
# =============================================

# =============================================
# torch.cuda.amp.custom_fwd is deprecated >= 2.4
torch_version = torch.__version__
if DEVICE_TYPE in ("cuda", "hip"):
    if Version(torch_version) < Version("2.4.0"):
        # Old location (deprecated from torch 2.4 onwards).
        torch_amp_custom_fwd = torch.cuda.amp.custom_fwd
        torch_amp_custom_bwd = torch.cuda.amp.custom_bwd
    else:
        # torch >= 2.4 moved the decorators to torch.amp with device_type.
        torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "cuda")
        torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "cuda")
elif DEVICE_TYPE == "xpu":
    if Version(torch_version) < Version("2.6.0"):
        raise RuntimeError("torch.xpu currently only supports torch.version >= 2.6.0")
    else:
        torch_amp_custom_fwd = torch.amp.custom_fwd(device_type = "xpu")
        torch_amp_custom_bwd = torch.amp.custom_bwd(device_type = "xpu")
# =============================================

# =============================================
# Fix KeyError: 'Cache only has 0 layers, attempted to access layer with index 0'
# import transformers.cache_utils
# if hasattr(transformers.cache_utils, "DynamicCache") and \
#     transformers.cache_utils.DynamicCache.__getitem__.__name__ != "__cache_utils_getitem__":

#     source = inspect.getsource(transformers.cache_utils.DynamicCache.__getitem__)
#     start = source.find("def")
#     spaces = start*" "
#     source = source.split("\n")
#     source = "\n".join(x[start:] for x in source)
#     where = source.find("raise KeyError")
#     source = source[:where] + \
#         f"if len(self) == 0:\n{spaces}{spaces}"\
#         "    raise RuntimeError('Unsloth: You must call `FastLanguageModel.for_inference(model)` before doing inference for Unsloth models.')\n" + \
#         f"{spaces}{spaces}else:\n{spaces}{spaces}{spaces}" + source[where:]
#     source = source.replace("__getitem__", "__cache_utils_getitem__", 1)
#     exec(source)
#     transformers.cache_utils.DynamicCache.__getitem__ = __cache_utils_getitem__
# pass
# =============================================

# =============================================
# Weird Databricks errors
from transformers.utils import is_openai_available

if is_openai_available():
    try:
        from openai import OpenAI
    except:
        # transformers thinks openai is installed but importing it fails
        # (seen on Databricks) - force transformers to report it as absent.
        print("Unsloth: OpenAI failed to import - ignoring for now.")
        import transformers.utils

        def _is_openai_available():
            return False

        transformers.utils.is_openai_available = _is_openai_available

# =============================================
# Get Flash Attention v2 if Ampere (RTX 30xx, A100)
import bitsandbytes as bnb

from transformers import AutoTokenizer
from transformers.utils.import_utils import _is_package_available

# Capability flags filled in per device type below.
SUPPORTS_BFLOAT16 = False
HAS_FLASH_ATTENTION = False
HAS_FLASH_ATTENTION_SOFTCAPPING = False

if DEVICE_TYPE == "cuda":
    major_version, minor_version = torch.cuda.get_device_capability()
    # Memoize - the capability is queried repeatedly later on.
    torch.cuda.get_device_capability = functools.cache(torch.cuda.get_device_capability)

    if major_version >= 8:
        # Ampere (SM 8.0) or newer.
        SUPPORTS_BFLOAT16 = True
        if _is_package_available("flash_attn"):
            # Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
            try:
                try:
                    # See https://github.com/unslothai/unsloth/issues/1437
                    from flash_attn.flash_attn_interface import flash_attn_gpu
                except:
                    from flash_attn.flash_attn_interface import flash_attn_cuda
                HAS_FLASH_ATTENTION = True

                # Also check for softcapping
                from flash_attn import __version__ as flash_attn_version

                HAS_FLASH_ATTENTION_SOFTCAPPING = Version(
                    flash_attn_version
                ) >= Version("2.6.3")
                if not HAS_FLASH_ATTENTION_SOFTCAPPING:
                    print(
                        "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"
                        "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"
                        "To update flash-attn, do the below:\n"
                        '\npip install --no-deps --no-build-isolation --upgrade "flash-attn>=2.6.3"'
                    )
            except:
                print(
                    "Unsloth: Your Flash Attention 2 installation seems to be broken?\n"
                    "A possible explanation is you have a new CUDA version which isn't\n"
                    "yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"
                    "We shall now use Xformers instead, which does not have any performance hits!\n"
                    "We found this negligible impact by benchmarking on 1x A100."
                )

                # Stop Flash Attention from importing!
                import transformers.utils.import_utils

                transformers.utils.import_utils.is_flash_attn_2_available = (
                    lambda *args, **kwargs: False
                )
                import transformers.utils

                transformers.utils.is_flash_attn_2_available = (
                    lambda *args, **kwargs: False
                )

                HAS_FLASH_ATTENTION = False
        else:
            HAS_FLASH_ATTENTION = False
    else:
        # Tri Dao's benchmark shows xformers is faster for now.
        HAS_FLASH_ATTENTION = False
elif DEVICE_TYPE == "hip":
    # ROCm: same flash-attn probing as the CUDA branch above.
    SUPPORTS_BFLOAT16 = True
    if _is_package_available("flash_attn"):
        # Check for CUDA linking errors "undefined symbol: _ZNK3c106SymIntltEl"
        try:
            try:
                # See https://github.com/unslothai/unsloth/issues/1437
                from flash_attn.flash_attn_interface import flash_attn_gpu
            except:
                from flash_attn.flash_attn_interface import flash_attn_cuda
            HAS_FLASH_ATTENTION = True

            # Also check for softcapping
            from flash_attn import __version__ as flash_attn_version

            HAS_FLASH_ATTENTION_SOFTCAPPING = Version(flash_attn_version) >= Version(
                "2.6.3"
            )
            if not HAS_FLASH_ATTENTION_SOFTCAPPING:
                print(
                    "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"
                    "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"
                    "To update flash-attn, do the below:\n"
                    '\npip install --no-deps --no-build-isolation --upgrade "flash-attn>=2.6.3"'
                )
        except:
            print(
                "Unsloth: Your Flash Attention 2 installation seems to be broken?\n"
                "A possible explanation is you have a new CUDA version which isn't\n"
                "yet compatible with FA2? Please file a ticket to Unsloth or FA2.\n"
                "We shall now use Xformers instead, which does not have any performance hits!\n"
                "We found this negligible impact by benchmarking on 1x A100."
            )

            # Stop Flash Attention from importing!
            import transformers.utils.import_utils

            transformers.utils.import_utils.is_flash_attn_2_available = (
                lambda *args, **kwargs: False
            )
            import transformers.utils

            transformers.utils.is_flash_attn_2_available = lambda *args, **kwargs: False

            HAS_FLASH_ATTENTION = False
elif DEVICE_TYPE == "xpu":
    SUPPORTS_BFLOAT16 = True

# =============================================
# Get Xformers
# Silence xformers CUDA mismatch warnings before import
try:
    _xformers_logger = logging.getLogger("xformers")
    _xformers_logger.setLevel(logging.ERROR)
    del _xformers_logger
except:
    pass
try:
    from xformers import __version__ as xformers_version

    # [TODO] Xformers does NOT work on RTX 50x (12), B200 (10), Jetson (11)
    # See https://github.com/facebookresearch/xformers/issues/1329
    # CUDA error (/workspace/xfrm2/third_party/flash-attention/hopper/flash_fwd_launch_template.h:188)
    # NOTE(review): this raises on machines without CUDA; the generic
    # `except Exception` below then disables xformers - presumably intended.
    major_version, minor_version = torch.cuda.get_device_capability()
    if (f"{major_version}.{minor_version}" in ("10.0", "11.0", "12.0")) and (
        Version(xformers_version) in (Version("0.0.32.post2"),)
    ):
        raise NotImplementedError(
            "Unsloth: Xformers does not work in RTX 50X, Blackwell GPUs as of yet. Please build from source via\n"
            "```\n"
            "pip install ninja\n"
            "pip install -v --no-build-isolation -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers\n"
            "```\n"
        )
    # Temporarily disable 0.0.27 and higher - inference issues
    if False:  # Version(xformers_version) >= Version("0.0.27"):
        raise ImportError(
            "Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "
            "then press Disconnect Runtime and then Restart it.\n"
            "\n"
            "%%capture\n"
            "# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
            '!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
            '!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'
            "\n"
            f"Otherwise in local machines, your xformers version of {xformers_version} is too new.\n"
            'Please downgrade xformers via `pip install --force-reinstall "xformers<=0.0.27"'
        )

    # Enforce the torch <-> xformers compatibility matrix.
    if Version(torch_version) < Version("2.2.0") and Version(
        xformers_version
    ) >= Version("0.0.24"):
        raise ImportError(
            f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"
            f"Please install xformers < 0.0.24 for torch = {torch_version}."
        )
    elif Version(torch_version) < Version("2.3.0") and Version(
        xformers_version
    ) >= Version("0.0.26"):
        raise ImportError(
            f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"
            f"Please install xformers < 0.0.26 for torch = {torch_version}."
        )
    elif Version(torch_version) < Version("2.4.0") and Version(
        xformers_version
    ) > Version("0.0.27"):
        raise ImportError(
            f"Unsloth: You have torch = {torch_version} but xformers = {xformers_version}.\n"
            f"Please install xformers <= 0.0.27 for torch = {torch_version}."
        )

    from xformers._cpp_lib import _register_extensions

    try:
        _register_extensions()  # Check if C++ modules are loaded correctly
    except Exception as error:
        raise ImportError(
            "Unsloth: Xformers was not installed correctly.\n"
            "Please install xformers separately first.\n"
            "Then confirm if it's correctly installed by running:\n"
            "python -m xformers.info\n\n"
            "Longer error message:\n" + str(error)
        )
    import xformers.ops.fmha as xformers

    xformers_attention = xformers.memory_efficient_attention
except ModuleNotFoundError:
    # xformers not installed at all - fall back to PyTorch attention.
    xformers = None
    xformers_attention = None
    xformers_version = None
except Exception as e:
    # xformers present but unusable (broken build, unsupported GPU, ...).
    if os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") != "0":
        print(
            "========\nSwitching to PyTorch attention since your Xformers is broken.\n========\n"
        )
        print(str(e))
    xformers = None
    xformers_attention = None
    xformers_version = None

# Check TRL version
from trl import __version__ as trl_version

# Unsloth now supports all TRL versions!
# NOTE: this guard is intentionally disabled (`if False`); the old upper-bound
# check and its Colab install instructions are kept for reference only.
if False:  # Version(trl_version) >= Version("0.9.0"):
    raise ImportError(
        "Unsloth: If you are in Colab, we updated the top cell install instructions - please change it to below "
        "then press Disconnect Runtime and then Restart it.\n"
        "\n"
        "%%capture\n"
        "# Installs Unsloth, Xformers (Flash Attention) and all other packages!\n"
        '!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"\n'
        '!pip install --no-deps "xformers<=0.0.27" trl peft accelerate bitsandbytes\n'
        "\n"
        f"Otherwise in local machines, your TRL version of {trl_version} is too new.\n"
        "Please downgrade TRL via `pip install --force-reinstall trl"
    )

# =============================================
# Fix new Xformers versions TypeError: Multiple dispatch failed for 'torch._ops.aten.to.dtype_layout'
# accelerate_old_send_to_device = None
# accelerate_new_send_to_device = None
# if xformers_version is not None and Version(xformers_version) >= Version("0.0.27"):
#     import accelerate.utils.operations
#     if hasattr(accelerate.utils.operations, "send_to_device") and \
#         accelerate.utils.operations.send_to_device.__name__ != "_fixed_send_to_device":
#         accelerate_old_send_to_device = accelerate.utils.operations.send_to_device
#         from accelerate.utils.operations import *
#         send_to_device = inspect.getsource(accelerate.utils.operations.send_to_device)
#         send_to_device = re.sub(
#             r"([ ]{4,})return tensor\.to\(device\)",
#             r"\1try: return tensor.to(device)\n\1except: return tensor",
#             send_to_device,
#         ).replace("def send_to_device", "def _fixed_send_to_device")
#         exec(send_to_device)
#         # accelerate.utils.operations.send_to_device = _fixed_send_to_device
#         accelerate_new_send_to_device = _fixed_send_to_device
#     pass
# pass

# Transformers 4.46 breaks dynamic caching. This is a hack
import transformers.generation.configuration_utils

# Ensure "dynamic" is a recognised cache implementation (appends in place).
_cache_implementations = getattr(
    transformers.generation.configuration_utils,
    "ALL_CACHE_IMPLEMENTATIONS",
    None,
)
if type(_cache_implementations) is list and "dynamic" not in _cache_implementations:
    _cache_implementations.append("dynamic")
# =============================================

# =============================================
# Torch compile settings
# Each flag is controlled by an environment variable; "1" enables it.
UNSLOTH_COMPILE_DEBUG = os.environ.get("UNSLOTH_COMPILE_DEBUG", "0") == "1"
UNSLOTH_COMPILE_MAXIMUM = os.environ.get("UNSLOTH_COMPILE_MAXIMUM", "0") == "1"
# Note: unlike the two flags above, ignore-errors defaults to ON ("1").
UNSLOTH_COMPILE_IGNORE_ERRORS = (
    os.environ.get("UNSLOTH_COMPILE_IGNORE_ERRORS", "1") == "1"
)
# Just remove max_autotune_gemm warning
from torch._inductor.runtime.hints import DeviceProperties


@functools.lru_cache(None)
def is_big_gpu(index) -> bool:
    """Return True when the device has enough compute units for max-autotune GEMM.

    Replaces torch inductor's own `is_big_gpu`; the thresholds (16 for XPU,
    80 SMs otherwise) mirror the original branchy implementation.
    """
    if DEVICE_TYPE == "xpu":
        device_kind, min_sms = "xpu", 16
    else:
        device_kind, min_sms = "cuda", 80
    # `index` may already be a device object; only ints get wrapped.
    device = torch.device(device_kind, index) if type(index) is int else index
    properties = DeviceProperties.create(device)
    return properties.multi_processor_count >= min_sms


import torch._inductor.utils

# Override inductor's SM-count gate with the version defined above.
torch._inductor.utils.is_big_gpu = is_big_gpu
# Apply Unsloth's torch.compile configuration using the env-driven flags.
patch_torch_compile(
    debug = UNSLOTH_COMPILE_DEBUG,
    O3 = UNSLOTH_COMPILE_MAXIMUM,
    ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS,
)

# Options forwarded to torch.compile(..., options = ...) throughout Unsloth.
torch_compile_options = {
    "epilogue_fusion": True,
    "max_autotune": True,
    "shape_padding": True,
    "trace.enabled": UNSLOTH_COMPILE_DEBUG,
    "triton.cudagraphs": False,
}

import accelerate


def torch_compile_kwargs(*args, **kwargs):
    # Replacement for TorchDynamoPlugin.to_kwargs: always hand accelerate
    # Unsloth's own torch.compile settings, ignoring whatever it computed.
    print("Unsloth: Enabled auto compiling")
    return {
        "dynamic": True,
        "fullgraph": False,
        "options": torch_compile_options,
    }


# Patch every import path under which accelerate exposes TorchDynamoPlugin.
accelerate.utils.dataclasses.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs
accelerate.utils.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs
accelerate.accelerator.TorchDynamoPlugin.to_kwargs = torch_compile_kwargs
del accelerate


def patch_regional_compilation():
    """Wrap torch.nn.ModuleList so each listed layer is torch.compile'd
    individually (torch >= 2.5 "regional" compilation)."""
    # Already patched?
    if torch.nn.ModuleList.__name__ == "UnslothModuleList":
        return
    # Regional recompilation only works from torch 2.5 onwards.
    if Version(torch.__version__) < Version("2.5.0"):
        return

    _OriginalModuleList = torch.nn.ModuleList
    os.environ["UNSLOTH_PATCHED"] = "1"

    def UnslothModuleList(*args, **kwargs):
        # Intercept ModuleList([m0, m1, ...]) and compile each member.
        if len(args) == 1 and not kwargs and type(args[0]) is list:
            compiled = [
                torch.compile(
                    module,
                    dynamic = True,
                    options = torch_compile_options,
                    fullgraph = False,
                )
                for module in args[0]
            ]
            args = [_OriginalModuleList(compiled)]
        return _OriginalModuleList(*args, **kwargs)

    UnslothModuleList.__doc__ = _OriginalModuleList.__doc__

    torch.nn.ModuleList = UnslothModuleList
    return


# =============================================


def prepare_model_for_kbit_training(
    model: Any,
    use_gradient_checkpointing: Optional[bool] = True,
    use_reentrant: Optional[bool] = True,
) -> Any:
    """Backwards-compatible shim around `prepare_model_for_training`.

    Keeps the historical kbit-training entry point but always requests the
    adapter-style defaults: no full finetuning, frozen layernorms /
    embeddings / lm_head, and float32 mixed precision.

    NOTE(review): the original annotated `use_gradient_checkpointing` with a
    bare `Optional` (the typing special form itself, meaningless as a type);
    fixed to `Optional[bool]`. Callers may also pass non-bool sentinels —
    confirm before tightening further.
    """
    return prepare_model_for_training(
        model = model,
        use_gradient_checkpointing = use_gradient_checkpointing,
        use_reentrant = use_reentrant,
        full_finetuning = False,
        train_layernorms = False,
        train_embedding = False,
        train_lm_head = False,
        float32_mixed_precision = True,
    )


# =============================================
# Weirdly LoraLayer.update_layer downcasts PEFT layers to float16??
# For mixed precision, we need it to be in float32 not float16.
from peft import __version__ as peft_version
from peft.utils.integrations import dequantize_module_weight

if Version(peft_version) < Version("0.12.0"):
    from peft.tuners.lora.layer import LoraLayer

    try:
        # Cut out the body of the `if weight is not None:` branch (the code
        # that downcasts LoRA weights), then re-exec the patched function at
        # module level as `LoraLayer_update_layer`.
        source = inspect.getsource(LoraLayer.update_layer)
        text = "if weight is not None:\n"
        start = source.find(text) + len(text)
        end = source.find("self.to(weight.device)", start)
        spaces = re.findall(r"^([ ]{1,})break", source, flags = re.MULTILINE)[0]
        source = source.replace(source[start:end], spaces)
        # De-indent from method level to module level so exec() works.
        spaces = len(re.match(r"[\s]{1,}", source).group(0))
        lines = source.split("\n")
        source = "\n".join(x[spaces:] for x in lines)
        # The method source refers to `nn.` which is not in our globals.
        source = re.sub(r"([^\.])nn\.", r"\1torch.nn.", source)
        source = source.replace("def update_layer", "def LoraLayer_update_layer")
        exec(source, globals())

        # Fix up incorrect downcasting of LoRA weights
        from peft.tuners.lora.layer import LoraLayer

        LoraLayer.update_layer = LoraLayer_update_layer
        from peft.tuners.lora import LoraLayer

        LoraLayer.update_layer = LoraLayer_update_layer
    # Was a bare `except:`, which would also swallow KeyboardInterrupt and
    # SystemExit during import - narrow to Exception.
    except Exception:
        logger.warning_once(
            "Unsloth unsuccessfully patched LoraLayer.update_layer. Please file a bug report.\n"
            "Luckily, your training run will still work in the meantime!"
        )

# =============================================
import importlib

# Module-level flag: route hub downloads through ModelScope instead of HF.
# (The original `global` statement was a no-op at module scope and is dropped.)
USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
if USE_MODELSCOPE and importlib.util.find_spec("modelscope") is None:
    raise ImportError(
        f"You are using the modelscope hub, please install modelscope by `pip install modelscope -U`"
    )

import socket


@functools.lru_cache(1)
def has_internet(host = "8.8.8.8", port = 53, timeout = 3):
    """Best-effort connectivity probe (one TCP connect to Google DNS by default).

    Returns False immediately when TRANSFORMERS_OFFLINE=1. The result is
    cached by lru_cache, so the network is probed at most once per process.

    Fixes vs. the original: the socket is now closed (it was leaked), and the
    timeout is per-connection instead of mutating the process-global default
    via `socket.setdefaulttimeout`.
    """
    if os.environ.get("TRANSFORMERS_OFFLINE", "0") == "1":
        return False
    try:
        with socket.create_connection((host, port), timeout = timeout):
            return True
    except OSError:  # socket.error is an alias of OSError
        return False


import psutil


def _get_statistics(statistics = None, force_download = True):
    """Record an anonymous environment label by downloading a tiny public HF repo.

    `statistics` is either an explicit label (e.g. "repeat", "vram-16") or
    None, in which case the runtime environment (kaggle / colab / runpod /
    cloud vendor) is auto-detected from filesystem markers, environment
    variable names, and DMI vendor files.

    NOTE(review): the `force_download` parameter is never forwarded —
    `snapshot_download` below hard-codes `force_download = True`. Confirm
    whether the `force_download = False` call site is intentional.
    """
    # We log some basic stats about which environment is being used.
    # We simply download a README.md file from HF - all data is made public.
    # This is simply so we can check if some envs are broken or not.
    # You can disable this by commenting the below out
    n_cpus = psutil.cpu_count(logical = False)
    keynames = "\n" + "\n".join(os.environ.keys())
    # Check modelscope for down detection
    global USE_MODELSCOPE
    USE_MODELSCOPE = os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"

    if statistics is None:
        # Prefer filesystem markers (harder to misidentify) before env-key matching
        try:
            from pathlib import Path

            if Path("/kaggle/working").exists():
                statistics = "kaggle"
            elif Path("/content").exists() and Path("/opt/colab").exists():
                statistics = "colab" if n_cpus == 1 else "colabpro"
            elif Path("/runpod-volume").exists():
                statistics = "runpod"
        except Exception:
            pass

        # Fallback to env-key detection
        if statistics is None:
            if "\nKAGGLE_" in keynames:
                statistics = "kaggle"
            elif "\nCOLAB_" in keynames and n_cpus == 1:
                statistics = "colab"
            elif "\nCOLAB_" in keynames:
                statistics = "colabpro"
            elif "\nRUNPOD_" in keynames:
                statistics = "runpod"
            elif "\nAWS_" in keynames:
                statistics = "aws"
            elif "\nAZURE_" in keynames:
                statistics = "azure"
            # elif "\nK_" in keynames or "\nFUNCTION_" in keynames: statistics = "gcp"
            elif "\nINVOCATION_ID" in keynames:
                statistics = "lambda"
            # else: statistics = "other"
            else:
                # Last resort: read DMI vendor strings to identify the cloud.
                # `Path` here comes from the filesystem-marker import above.

                def try_vllm_check():
                    vendor_files = (
                        "/sys/class/dmi/id/product_version",
                        "/sys/class/dmi/id/bios_vendor",
                        "/sys/class/dmi/id/product_name",
                        "/sys/class/dmi/id/chassis_asset_tag",
                        "/sys/class/dmi/id/sys_vendor",
                    )

                    for vendor_file in vendor_files:
                        path = Path(vendor_file)
                        if path.is_file():
                            file_content = path.read_text().lower()
                            if "amazon" in file_content:
                                return "aws"
                            elif "microsoft corporation" in file_content:
                                return "azure"
                            elif "google" in file_content:
                                return "gcp"
                    return "other"

                try:
                    statistics = try_vllm_check()
                except Exception:
                    statistics = "other"

    if statistics is not None:
        import tempfile
        from huggingface_hub import snapshot_download
        from unsloth_zoo.rl_environments import execute_with_time_limit

        if has_internet():
            # Download `unslothai/<label>` into a throwaway directory; the
            # download itself is the telemetry/availability signal.

            def stats_check():
                with tempfile.TemporaryDirectory(ignore_cleanup_errors = True) as f:
                    snapshot_download(
                        f"unslothai/{statistics}",
                        force_download = True,
                        cache_dir = f,
                        local_dir = f,
                    )

            # Give HF at most 120 seconds before declaring it down.
            time_limited_stats_check = execute_with_time_limit(120)(stats_check)
            try:
                time_limited_stats_check()
            except TimeoutError:
                raise TimeoutError(
                    "Unsloth: HuggingFace seems to be down after trying for 120 seconds :(\n"
                    "Check https://status.huggingface.co/ for more details.\n"
                    "As a temporary measure, use modelscope with the same model name ie:\n"
                    "```\n"
                    "pip install modelscope\n"
                    "import os; os.environ['UNSLOTH_USE_MODELSCOPE'] = '1'\n"
                    "from unsloth import FastLanguageModel\n"
                    "model = FastLanguageModel.from_pretrained('unsloth/gpt-oss-20b')\n"
                    "```"
                )
            except Exception:
                # Try no time limit check
                stats_check()


def get_statistics(local_files_only = False):
    """Log anonymous environment / VRAM / device-count labels via `_get_statistics`.

    Skipped entirely when UNSLOTH_DISABLE_STATISTICS is set, when ModelScope
    mode is enabled, or when `local_files_only` is True. Also doubles as a
    HuggingFace-availability check (the labels are fetched from HF).
    """
    import os

    if (
        "UNSLOTH_DISABLE_STATISTICS" in os.environ
        or os.environ.get("UNSLOTH_USE_MODELSCOPE", "0") == "1"
    ):
        return
    if local_files_only:
        return
    from huggingface_hub.utils import (
        disable_progress_bars,
        enable_progress_bars,
        are_progress_bars_disabled,
    )

    # Silence HF progress bars for the duration, restoring them afterwards.
    were_enabled = not are_progress_bars_disabled()
    if were_enabled:
        disable_progress_bars()

    _get_statistics(None)
    _get_statistics("repeat", force_download = False)

    if DEVICE_TYPE == "xpu":
        total_memory = torch.xpu.get_device_properties(0).total_memory
    else:
        total_memory = torch.cuda.get_device_properties(0).total_memory
    vram = total_memory / 1024 / 1024 / 1024
    # Round VRAM up to the nearest known bucket; anything larger counts as 96.
    for bucket in (8, 16, 20, 24, 40, 48, 80):
        if vram <= bucket:
            vram = bucket
            break
    else:
        vram = 96
    _get_statistics(f"vram-{vram}")
    # Device count is capped: 9 stands for "more than 8".
    _get_statistics(f"{DEVICE_COUNT if DEVICE_COUNT <= 8 else 9}")

    if were_enabled:
        enable_progress_bars()


# =============================================
# Fixes Bitsandbytes to remove missing warnings
from transformers.utils.quantization_config import (
    BitsAndBytesConfig,
    QuantizationMethod,
)

# Grab the source of BitsAndBytesConfig.__init__ and strip the
# `if kwargs: ...` line that warns about unused keyword arguments.
BitsAndBytesConfig__init__ = inspect.getsource(BitsAndBytesConfig.__init__)
BitsAndBytesConfig__init__ = re.sub(
    r"if[\s]{1,}kwargs\:[\s]{1,}.+?\n",
    "",
    BitsAndBytesConfig__init__,
    flags = re.MULTILINE,
)
# De-indent the method source so it can be exec'd at module level.
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.split("\n")
length_spaces = len(re.match(r"[\s]{1,}", BitsAndBytesConfig__init__[0]).group(0))
BitsAndBytesConfig__init__ = "\n".join(
    x[length_spaces:] for x in BitsAndBytesConfig__init__
)
# Rename so exec() defines `_BitsAndBytesConfig__init__` in globals();
# it is installed onto the class further below.
BitsAndBytesConfig__init__ = BitsAndBytesConfig__init__.replace(
    "__init__",
    "_BitsAndBytesConfig__init__",
)
exec(BitsAndBytesConfig__init__, globals())

if DEVICE_COUNT == 1:
    from accelerate.utils.dataclasses import DistributedType

    def _prepare_backend(self, *args, **kwargs):
        # Single-device runs never need a distributed backend.
        return None, DistributedType.NO

    import accelerate.state

    # Force accelerate into non-distributed mode on single-device setups.
    accelerate.state.PartialState._prepare_backend = _prepare_backend
    accelerate.accelerator.Accelerator.distributed_type = (
        lambda *args, **kwargs: DistributedType.NO
    )


# to move multiple tensors to the same device
def move_to_device(target_device, *tensors):
    """
    Move multiple tensors to target device if they're not already there.

    Args:
        target_device: device index (int), device string (e.g. "cuda:0"),
            or a torch.device instance
        *tensors: tensors to potentially move

    Returns:
        A single tensor when one tensor was given, otherwise a tuple —
        each on the target device (same objects if already there).
    """
    if isinstance(target_device, (int, str)):
        # torch.device accepts both an index and a name like "cuda:0".
        target_device = torch.device(target_device)
    elif not isinstance(target_device, torch.device):
        raise ValueError(f"Invalid target device: {target_device}")
    moved = tuple(
        t if t.device == target_device else t.to(target_device) for t in tensors
    )
    return moved if len(moved) > 1 else moved[0]


import transformers.utils.quantization_config

# Install the patched __init__ (defined via exec earlier in this file)
# that no longer warns about unused kwargs.
transformers.utils.quantization_config.BitsAndBytesConfig.__init__ = (
    _BitsAndBytesConfig__init__
)
# =============================================

# Offloading to disk for modules (lm_head, embed_tokens)
import pickle


def offload_to_disk(
    W, model, name, temporary_location: str = "_unsloth_temporary_saved_buffers"
):
    file_location = os.path.join(temporary_location, model.config._name_or_path)
    if not os.path.exists(file_location):
        os.makedirs(file_location)

    filename = os.path.join(file_location, f"{name}.pt")
    W = W.weight if hasattr(W, "weight") else W
    torch.save(
        W,
        filename,
        pickle_module = pickle,
        pickle_protocol = pickle.HIGHEST_PROTOCOL,
    )
    # We must use weights_only = False due to pickling
    offloaded_W = torch.load(
        filename, map_location = "cpu", mmap = True, weights_only = False
    )
    offloaded_W._offloaded_file_location = filename
    return offloaded_W


def offload_input_embeddings(
    model, temporary_location: str = "_unsloth_temporary_saved_buffers"
):
    """Swap the model's input embedding table for a disk-backed (mmap) copy."""
    mmap_weight = offload_to_disk(
        model.get_input_embeddings(), model, "input_embeddings", temporary_location
    )
    embedding = torch.nn.Embedding.from_pretrained(mmap_weight)
    embedding._offloaded_file_location = mmap_weight._offloaded_file_location
    model.set_input_embeddings(embedding)
    return


def offload_output_embeddings(
    model, temporary_location: str = "_unsloth_temporary_saved_buffers"
):
    """Swap the model's output head for a Linear backed by a disk (mmap) copy."""
    mmap_weight = offload_to_disk(
        model.get_output_embeddings(), model, "output_embeddings", temporary_location
    )

    # Build a skeleton Linear and splice the mmap tensor in as its weight.
    lm_head = torch.nn.Linear(1, 1, bias = None)
    del lm_head.weight
    lm_head.weight = mmap_weight
    lm_head.in_features = mmap_weight.shape[1]
    lm_head.out_features = mmap_weight.shape[0]

    lm_head._offloaded_file_location = mmap_weight._offloaded_file_location
    model.set_output_embeddings(lm_head)
    return


# Fixes a weird Torch 2.3 bug which says T4s have bfloat16
def is_bfloat16_supported() -> bool:
    """Return the module-level bfloat16 capability flag (computed at import)."""
    return SUPPORTS_BFLOAT16


def is_vLLM_available():
    """Return whether the `vllm` package is installed (via `_is_package_available`)."""
    return _is_package_available("vllm")


# Patches models to add RoPE Scaling
def patch_linear_scaling(
    model_name = "gemma2",
    rope_module = None,
    scaled_rope_module = None,
    attention_module = None,
):
    """Build a patched attention `__init__` source string with RoPE-scaling support.

    The generated source replaces the attention module's `self.rotary_emb = ...`
    assignment with a branch choosing `rope_module` (no rope_scaling in config)
    or `scaled_rope_module` (linear scaling).

    Returns:
        (init_function_name, source_string) on success;
        (None, None) when the attention source cannot be read (usually
        already patched);
        (None, source_with_imports) when no rotary_emb assignment was found.
    """
    assert rope_module is not None and scaled_rope_module is not None
    assert attention_module is not None

    # NOTE(review): rope_name / scaled_rope_name appear unused below — kept as-is.
    rope_name = rope_module.__name__
    scaled_rope_name = scaled_rope_module.__name__
    model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
    # Import preamble exec'd alongside the generated __init__ by the caller.
    exec_code = (
        f"import torch.nn as nn\n"
        f"from typing import Union, Optional, List, Any, Callable, Tuple\n"
        f"from {model_filepath} import logger, "
        f"{model_name.title()}Attention, {model_name.title()}Config"
    )

    try:
        function = inspect.getsource(attention_module.__init__)
    except:
        # Most likely already patched!
        return None, None
    # De-indent the method source to module level.
    where = function.find("def")
    function = function.split("\n")
    function = "\n".join(x[where:] for x in function)
    init_name = f"{model_name.title()}Attention__init__"
    function = function.replace("def __init__", f"def {init_name}")
    # The free function cannot use zero-arg super(); make it explicit.
    function = function.replace(
        "super().__init__()",
        f"super({model_name.title()}Attention, self).__init__()",
    )
    # Template spliced in place of the original rotary_emb assignment.
    fix_rope_function = """
    if getattr(self.config, "rope_scaling", None) is None:
        self.rotary_emb = {rope_function}(
            dim = self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )
    else:
        scaling_type = self.config.rope_scaling["type"]
        scaling_factor = self.config.rope_scaling["factor"]
        if scaling_type == "linear":
            self.rotary_emb = {scaled_rope_function}(
                dim = self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
                base=self.rope_theta,
            )
        else:
            raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
    pass
    """
    fix_rope_function = fix_rope_function.format(
        rope_function = rope_module.__name__,
        scaled_rope_function = scaled_rope_module.__name__,
    )
    rotary_emb = re.findall(
        r"self\.rotary\_emb \= .+?\)",
        function,
        flags = re.DOTALL | re.MULTILINE,
    )
    if len(rotary_emb) == 0:
        return None, exec_code + "\n\n" + function

    rotary_emb = rotary_emb[0]
    function = function.replace(rotary_emb, fix_rope_function, 1)
    function = exec_code + "\n\n" + function
    return init_name, function


# Patches for Llama-3 LlamaExtendedRotaryEmbedding
def patch_llama_rope_scaling(
    model_name = "llama",
    rope_module = None,
    scaled_rope_module = None,
    extended_rope_module = None,
    attention_module = None,
    longrope_module = None,
):
    """Build a patched attention `__init__` source string with Llama-3 RoPE support.

    Like `patch_linear_scaling`, but the generated branch additionally handles
    the "llama3" (extended) and "longrope" scaling types; `longrope_module`
    falls back to `rope_module` when not given.

    Returns:
        (init_function_name, source_string) on success;
        (None, None) when the attention source cannot be read (usually
        already patched);
        (None, function_source) when no rotary_emb assignment was found.
    """
    assert (
        rope_module is not None
        and scaled_rope_module is not None
        and extended_rope_module is not None
    )
    assert attention_module is not None

    # NOTE(review): rope_name / scaled_rope_name appear unused below — kept as-is.
    rope_name = rope_module.__name__
    scaled_rope_name = scaled_rope_module.__name__
    model_filepath = f"transformers.models.{model_name}.modeling_{model_name}"
    # Import preamble exec'd alongside the generated __init__ by the caller.
    exec_code = (
        f"import torch.nn as nn\n"
        f"from typing import Union, Optional, List, Any, Callable, Tuple\n"
        f"from {model_filepath} import logger, "
        f"{model_name.title()}Attention, {model_name.title()}Config"
    )

    try:
        function = inspect.getsource(attention_module.__init__)
    except:
        # Most likely already patched!
        return None, None
    # De-indent the method source to module level.
    where = function.find("def")
    function = function.split("\n")
    function = "\n".join(x[where:] for x in function)
    init_name = f"{model_name.title()}Attention__init__"
    function = function.replace("def __init__", f"def {init_name}")
    # The free function cannot use zero-arg super(); make it explicit.
    function = function.replace(
        "super().__init__()",
        f"super({model_name.title()}Attention, self).__init__()",
    )
    # Template spliced in place of the original rotary_emb assignment.
    fix_rope_function = """
    if getattr(self.config, "rope_scaling", None) is None:
        self.rotary_emb = {rope_function}(
            dim = self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
        )
    else:
        scaling_type1 = self.config.rope_scaling.get("type", None)
        scaling_type2 = self.config.rope_scaling.get("rope_type", None)
        scaling_type = scaling_type1 if scaling_type1 is not None else scaling_type2
        scaling_factor = self.config.rope_scaling.get("factor")

        if scaling_type == "linear":
            self.rotary_emb = {scaled_rope_function}(
                dim = self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                scaling_factor=scaling_factor,
                base=self.rope_theta,
            )
        elif scaling_type == "llama3":
            self.rotary_emb = {extended_rope_function}(
                dim = self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        elif scaling_type == "longrope":
            self.rotary_emb = {longrope_rope_function}(
                dim = self.head_dim,
                max_position_embeddings = self.max_position_embeddings,
                original_max_position_embeddings = self.config.original_max_position_embeddings,
                base = self.rope_theta,
                short_factor = self.config.rope_scaling['short_factor'],
                long_factor  = self.config.rope_scaling['long_factor' ],
            )
        else:
            raise ValueError(f"Unknown RoPE scaling type {{scaling_type}}")
    pass
    """

    fix_rope_function = fix_rope_function.format(
        rope_function = rope_module.__name__,
        scaled_rope_function = scaled_rope_module.__name__,
        extended_rope_function = extended_rope_module.__name__,
        longrope_rope_function = (
            longrope_module if longrope_module is not None else rope_module
        ).__name__,
    )
    rotary_emb = re.findall(
        r"self\.rotary\_emb \= .+?\)",
        function,
        flags = re.DOTALL | re.MULTILINE,
    )
    if len(rotary_emb) == 0:
        # NOTE(review): unlike patch_linear_scaling, this early return does NOT
        # prepend exec_code to the returned source — confirm callers exec the
        # required imports themselves, otherwise this looks inconsistent.
        return None, function
    rotary_emb = rotary_emb[0]
    function = function.replace(rotary_emb, fix_rope_function, 1)
    function = exec_code + "\n\n" + function
    return init_name, function


def create_boolean_mask(n = 4096, sliding_window = 2048):
    """Build an (n, n) boolean attention mask where True marks masked positions.

    With sliding_window == 0 this is a plain causal mask (strictly-future
    positions masked); otherwise positions further back than the window are
    masked as well.
    """
    ones = torch.ones(n, n, dtype = torch.bool)
    if sliding_window == 0:
        # Plain causal: only strictly-future positions (j > i) are masked.
        return torch.triu(ones, diagonal = 1, out = ones)
    # Masked = strictly-future (j > i) OR older than the window (j < i - w).
    future  = torch.triu(ones, diagonal = 1)
    too_old = torch.tril(ones, diagonal = -(sliding_window + 1))
    return future | too_old


def test_mask_creation():
    """Cross-check create_boolean_mask against transformers' AttentionMaskConverter."""
    from transformers.modeling_attn_mask_utils import AttentionMaskConverter

    def reference_mask(n, window):
        # Boolean mask of the positions the HF converter sets to -inf.
        dense = (
            AttentionMaskConverter(
                is_causal = True,
                sliding_window = window,
            )
            .to_causal_4d(
                1,
                n,
                n,
                dtype = torch.float16,
            )
            .squeeze(0)
            .squeeze(0)
        )
        return dense == dense.min()

    for n in range(2, 23):
        for s in range(1, 23):
            ours = create_boolean_mask(n = n, sliding_window = s)
            assert torch.all(reference_mask(n, s) == ours)
        # sliding_window = 0 must match the no-window causal mask.
        ours = create_boolean_mask(n = n, sliding_window = 0)
        assert torch.all(reference_mask(n, None) == ours)


def _unsloth_pre_compute_loss(self, model, inputs, *args, **kwargs):
    """Wrapper around the trainer's original compute_loss.

    Forwards `num_items_in_batch` into `inputs` when available, drops the
    kwarg when the model does not supply it, and warns once when gradient
    accumulation is active without it.
    """
    num_items_in_batch = kwargs.get("num_items_in_batch", None)
    if "num_items_in_batch" in kwargs:
        if num_items_in_batch is None:
            # Remove it since the model does not support it!
            kwargs.pop("num_items_in_batch")
        elif "num_items_in_batch" not in inputs:
            inputs["num_items_in_batch"] = num_items_in_batch

    # Warn once when accumulating gradients without a token count.
    training_args = getattr(self, "args", self)
    if (
        num_items_in_batch is None
        and getattr(training_args, "gradient_accumulation_steps", 1) != 1
    ):
        # Unwrap PEFT / wrapper layers to report the real model class name.
        inner_model = model
        if hasattr(inner_model, "base_model"):
            inner_model = inner_model.base_model
        if hasattr(inner_model, "model"):
            inner_model = inner_model.model
        name = inner_model.__class__.__name__

        logger.warning_once(
            f"Unsloth: Not an error, but {name} does not accept `num_items_in_batch`.\n"
            "Using gradient accumulation will be very slightly less accurate.\n"
            "Read more on gradient accumulation issues here: https://unsloth.ai/blog/gradient"
        )
    return self._old_compute_loss(model, inputs, *args, **kwargs)


def patch_gradient_accumulation_fix(Trainer):
    """
    Patch a HF `Trainer` class in-place so gradient accumulation is correct.

    Applies up to three source-level patches (each at most once):
    1. Replaces `get_batch_samples` and wraps `compute_loss` so
       `num_items_in_batch` is forwarded; also rewrites transformers 4.57.0's
       in-place `loss *=` (which triggers "Output 0 of
       UnslothFusedLossBackward is a view and is being modified inplace.")
       into `loss = loss *`.
    2. Rebuilds `training_step` so the loss is only rescaled by
       `gradient_accumulation_steps` when `num_items_in_batch` is known.
    3. Rebuilds `__init__` so `model_accepts_loss_kwargs` detection cannot
       cause double loss scaling (transformers PR #37208).
    """
    import inspect

    if hasattr(Trainer, "get_batch_samples"):
        # Already patched — nothing to do.
        if Trainer.get_batch_samples.__name__ == "_unsloth_get_batch_samples":
            return
        if (
            not inspect.getsource(Trainer.get_batch_samples)
            .strip()
            .endswith("return batch_samples, num_items_in_batch")
        ):
            raise NotImplementedError(
                "Unsloth: Please make a Github issue immediately!!"
            )
        else:
            if Trainer.get_batch_samples.__name__ != "_unsloth_get_batch_samples":
                Trainer.get_batch_samples = _unsloth_get_batch_samples

            # Also fix passing in num_items_in_batch
            if not hasattr(Trainer, "_old_compute_loss"):
                # Fix transformers 4.57.0 causing `Output 0 of UnslothFusedLossBackward is a view and is being modified inplace.`
                function = inspect.getsource(Trainer.compute_loss)
                if "loss *=" in function or "loss*=" in function:
                    # De-indent the method source so it can be exec'd at module level.
                    where = function.find("def")
                    function = function.split("\n")
                    function = "\n".join(x[where:] for x in function)

                    # Import all variables that need importing
                    import transformers.trainer

                    items_in_trainer = dir(transformers.trainer)
                    good_items = []
                    for item in items_in_trainer:
                        if item in function:
                            good_items.append(item)
                    # Guard: `from x import ()` is a SyntaxError on an empty list.
                    if good_items:
                        exec(
                            "from transformers.trainer import ("
                            + ", ".join(x for x in good_items)
                            + ")",
                            globals(),
                        )

                    # Replace loss*= with loss = loss *
                    function = re.sub(
                        r"loss[\s]{0,}\*\=",
                        "loss = loss *",
                        function,
                    )
                    exec(function, globals())
                    Trainer.compute_loss = compute_loss
                Trainer._old_compute_loss = Trainer.compute_loss
                Trainer.compute_loss = _unsloth_pre_compute_loss
    else:
        logger.warning_once(
            "Unsloth: We fixed a gradient accumulation bug, "
            "but it seems like you don't have the latest transformers version!\n"
            "Please update transformers, TRL and unsloth via:\n"
            "`pip install --upgrade --no-cache-dir --no-deps unsloth transformers git+https://github.com/huggingface/trl.git`"
        )

    # Also fix up loss scaling ie negate loss *= self.args.gradient_accumulation_steps
    if not (
        Trainer.training_step.__name__ == "_unsloth_training_step"
        or "num_items_in_batch"
        not in inspect.signature(Trainer.training_step).parameters
    ):
        function = inspect.getsource(Trainer.training_step)
        where = function.find("def")
        function = function.split("\n")
        function = "\n".join(x[where:] for x in function)

        # Import all variables that need importing
        import transformers.trainer

        items_in_trainer = dir(transformers.trainer)
        good_items = []
        for item in items_in_trainer:
            if item in function:
                good_items.append(item)
        if good_items:
            exec(
                "from transformers.trainer import ("
                + ", ".join(x for x in good_items)
                + ")",
                globals(),
            )

        # Accelerate does / self.args.gradient_accumulation_steps internally, so if we already
        # summed it up and did the division before hand, we have to negate it.
        function = function.replace(
            "loss *= self.args.gradient_accumulation_steps",
            "if num_items_in_batch is not None: loss *= self.args.gradient_accumulation_steps",
        )
        function = function.replace(
            "def training_step", "def _unsloth_training_step", 1
        )

        # Fix 4.47.0 issue where num_items_in_batch was removed
        # See https://github.com/huggingface/transformers/pull/35121
        function = function.replace(
            "if self.model_accepts_loss_kwargs:",
            "if False:",
        )

        # Fix when num_items_in_batch is nothing
        # https://github.com/huggingface/transformers/pull/35207
        # BUGFIX: the replacement template must use raw strings. With plain
        # strings, "\1"/"\2"/"\3" are interpreted as control characters
        # (\x01 etc.) instead of regex backreferences, silently corrupting
        # the patched source whenever this pattern matched.
        function = re.sub(
            r"else:\n"
            r"([\s]{4,})self\.accelerator\.backward\(loss, \*\*kwargs\)\n"
            r"(.+?)if num_items_in_batch is None\:\n"
            r"(.+?)return loss\.detach\(\) \/ self\.args\.gradient_accumulation_steps",
            "else:\n"
            r"\2if num_items_in_batch is None:\n"
            r"\3loss = loss / self.args.gradient_accumulation_steps\n"
            r"\1self.accelerator.backward(loss, **kwargs)",
            function,
        )

        exec(function, globals())
        Trainer.training_step = _unsloth_training_step

    # Prevent double scaling gradient accumulation
    # https://github.com/huggingface/transformers/pull/37208
    # Patch model_accepts_loss_kwargs detection in Trainer.__init__
    if Trainer.__init__.__name__ != "_unsloth___init__":
        try:
            init_function = inspect.getsource(Trainer.__init__)
        except Exception:
            init_function = ""
        # BUGFIX: check truthiness, not `is not None` — a failed getsource
        # leaves "", and exec("") would never define `_unsloth___init__`,
        # making the assignment below raise NameError.
        if init_function:
            init_function = textwrap.dedent(init_function)

            # Import all variables that need importing
            import transformers.trainer

            items_in_trainer = dir(transformers.trainer)
            good_items = []
            for item in items_in_trainer:
                if item in init_function:
                    good_items.append(item)
            if good_items:
                exec(
                    "from transformers.trainer import ("
                    + ", ".join(x for x in good_items)
                    + ")",
                    globals(),
                )

            init_function = init_function.replace(
                "def __init__", "def _unsloth___init__", 1
            )

            # Force else branch
            init_function = re.sub(
                r'if[\s]+hasattr\(\s*unwrapped_model\s*,\s*"accepts_loss_kwargs"\s*\)\s*:',
                'if hasattr(unwrapped_model, "accepts_loss_kwargs") and False:',
                init_function,
            )
            exec(init_function, globals())
            Trainer.__init__ = _unsloth___init__


def patch_tokenizer(model, tokenizer):
    """
    Apply Unsloth's tokenizer fixes and stamp the model config with the
    current Unsloth version. Returns the (possibly patched) pair.
    """
    patched_model, patched_tokenizer = _patch_tokenizer(model, tokenizer)
    if patched_model is not None:
        patched_model.config.update({"unsloth_version": __version__})
    return patched_model, patched_tokenizer


def patch_fast_lora():
    """Swap PEFT's bnb `Linear4bit.forward` for Unsloth's fast LoRA forward."""
    import peft.tuners.lora.bnb as lora_bnb

    lora_bnb.Linear4bit.forward = fast_lora_forward


def unsloth_compile_transformers(
    dtype,
    model_name,
    model_types,
    token = None,
    revision = None,
    trust_remote_code = False,
    sdpa_dynamic_mask = True,
    sdpa_bool_masks = True,
    sdpa_gqa_replace = True,
    sdpa_dynamic_compile = True,
    compile_attention = True,
    disable_causal_masks = True,
    compile_torch_modules = True,
    compile_custom_modules = True,
    compile_function_calls = True,
    fuse_lm_head = True,
    gradient_checkpointing = True,
    manual_replacements = True,
    fast_lora_forwards = True,
    fast_residual_stream = True,
    accurate_accumulation = True,
    epilogue_fusion = True,
    max_autotune = False,
    shape_padding = True,
    cudagraphs = False,
    debug = False,
    fullgraph = True,
    import_from_cache = False,
    disable = False,
    return_logits = False,
    unsloth_force_compile = False,
):
    """
    Run Unsloth's compiler pass over each model type, then re-apply the
    temporary patches that override the compiler.

    Returns (model_types, supports_sdpa) on success, (model_types, False)
    when compilation is skipped, or None when Torch is too old.
    """
    # Vision / newer optimized models require Torch >= 2.4.
    if Version(torch_version) < Version("2.4.0"):
        old_torch_message = (
            "Unsloth: Unfortunately Unsloth vision and other newer optimized models need Torch 2.4 or later.\n"
            f"You have Torch version {torch_version}. Please upgrade your Torch version by visiting https://pytorch.org/\n"
            "For now your models will not get optimized, but will still work for now!"
        )
        print("=" * 30 + old_torch_message)
        return
    # Remote code cannot be traced, so skip optimizations unless forced.
    if trust_remote_code and unsloth_force_compile == False:
        print(
            "Unsloth: We can't trace models if `trust_remote_code = True`, "
            "so turning off some optimizations!"
        )
        return model_types, False
    # De-duplicate while preserving first-seen order.
    model_types = list(dict.fromkeys(model_types))
    if disable:
        return model_types, False

    # Mutated in-place by _unsloth_compile_transformers for each model type.
    supports_sdpa = [True]
    compile_options = dict(
        sdpa_dynamic_mask = sdpa_dynamic_mask,
        sdpa_bool_masks = sdpa_bool_masks,
        sdpa_gqa_replace = sdpa_gqa_replace,
        sdpa_dynamic_compile = sdpa_dynamic_compile,
        compile_attention = compile_attention,
        disable_causal_masks = disable_causal_masks,
        compile_torch_modules = compile_torch_modules,
        compile_custom_modules = compile_custom_modules,
        compile_function_calls = compile_function_calls,
        fuse_lm_head = fuse_lm_head,
        gradient_checkpointing = gradient_checkpointing,
        manual_replacements = manual_replacements,
        fast_lora_forwards = fast_lora_forwards,
        fast_residual_stream = fast_residual_stream,
        accurate_accumulation = accurate_accumulation,
        epilogue_fusion = epilogue_fusion,
        max_autotune = max_autotune,
        shape_padding = shape_padding,
        cudagraphs = cudagraphs,
        debug = debug,
        fullgraph = fullgraph,
        import_from_cache = import_from_cache,
        disable = disable,
        return_logits = return_logits,
        supports_sdpa = supports_sdpa,
    )
    for model_type in model_types:
        _unsloth_compile_transformers(model_type, **compile_options)

    # Redo patches which override compiler
    for temporary_patch in TEMPORARY_PATCHES:
        temporary_patch()
    return model_types, supports_sdpa[0]


# We need an empty logits flag to warn people logits will not be returned anymore unless asked ie
# os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
# Fixed: the inline-code backtick around `"1"` was previously unbalanced.
LOGITS_ERROR_STRING = (
    "Unsloth: Logits are empty from 2024.11 onwards. To get raw logits again, please "
    'set the environment variable `UNSLOTH_RETURN_LOGITS` to `"1"` BEFORE starting to train ie before `trainer.train()`. For example:\n'
    "```\nimport os\n"
    "os.environ['UNSLOTH_RETURN_LOGITS'] = '1'\n"
    "trainer.train()\n```\n"
    "No need to restart your console - just add `os.environ['UNSLOTH_RETURN_LOGITS'] = '1'` before trainer.train() and re-run the cell!"
)


def raise_logits_error(*args, **kwargs):
    """Always raise: logits are disabled unless UNSLOTH_RETURN_LOGITS is set."""
    raise NotImplementedError(LOGITS_ERROR_STRING)


def return_none(*args, **kwargs):
    """Accept any call signature and do nothing (implicitly returns None)."""
    pass


class EmptyLogits:
    """Placeholder object returned in place of logits.

    Indexing raises the explanatory error. Attribute lookup resolves `.to`
    to `return_none` (so `logits.to(device)` quietly yields None) and every
    other attribute to `raise_logits_error`, which raises when called.
    """

    def __init__(self):
        pass

    def raise_getattr_error(self, attr):
        # Only `.to` stays callable without raising.
        return return_none if attr == "to" else raise_logits_error

    __getitem__ = raise_logits_error
    __getattr__ = raise_getattr_error

    def __repr__(self):
        return LOGITS_ERROR_STRING

    __str__ = __repr__


EMPTY_LOGITS = EmptyLogits()
# Shadow every tensor dunder on the singleton with a stub that just prints
# the dunder's name, so accidental tensor-style use is visible but harmless.
functions = dir(torch.Tensor)
for j, function in enumerate(functions):
    if not (function.startswith("__") and function.endswith("__")):
        continue

    def _dunder_stub(*args, _name = function, **kwargs):
        print(_name)

    # Preserve the module-level `raise_{j}` names the old exec() created.
    globals()[f"raise_{j}"] = _dunder_stub
    try:
        setattr(EMPTY_LOGITS, function, _dunder_stub)
    except:
        # Some attributes (e.g. __class__, __dict__) cannot be set on instances.
        continue


def validate_loftq_config(loftq_config, lora_dropout, bias, init_lora_weights, model):
    """
    Validate LoRA / LoftQ arguments and return a usable `loftq_config`.

    Warns on settings that disable fast patching (non-zero dropout, bias),
    rejects invalid `init_lora_weights`, and for LoftQ init checks PEFT
    support, supplies a default LoftQConfig, and forbids 4-bit loading.
    """
    from peft import LoraConfig

    if loftq_config is None:
        loftq_config = {}

    # Older PEFT releases do not expose `loftq_config` on LoraConfig.
    SUPPORTS_LOFTQ = "loftq_config" in str(inspect.signature(LoraConfig))

    if lora_dropout != 0:
        logger.warning_once(
            f"Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = {lora_dropout}.\n"
            f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
        )

    if bias != "none":
        logger.warning_once(
            f"Unsloth: bias = `none` is supported for fast patching. You are using bias = {bias}.\n"
            f"Unsloth will patch all other layers, except LoRA matrices, causing a performance hit."
        )

    is_valid_init = type(init_lora_weights) is bool or init_lora_weights in (
        "gaussian",
        "loftq",
        "corda",
    )
    if not is_valid_init:
        raise ValueError(
            'Unsloth: `init_lora_weights` must be either [True, False, "gaussian", "loftq", "corda"].'
        )

    # The remaining checks only apply to LoftQ initialization.
    if init_lora_weights != "loftq":
        return loftq_config

    if not SUPPORTS_LOFTQ:
        import peft

        raise RuntimeError(
            f"Unsloth: Your PEFT version of {peft.__version__} does not support LoftQ init.\n"
            "Please install PEFT 0.7.2 or higher.\n"
            "You can also install from source: `pip install git+https://github.com/huggingface/peft.git"
        )

    if loftq_config == {}:
        from peft import LoftQConfig

        logger.warning_once(
            "Unsloth: init_lora_weights = `loftq` is set, but `loftq_config` is None.\n"
            "We shall use `loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)`."
        )
        loftq_config = LoftQConfig(loftq_bits = 4, loftq_iter = 1)

    if hasattr(model.config, "quantization_config"):
        raise ValueError(
            "Unsloth: You are using `loftq` init, yet `load_in_4bit = True` was set.\n"
            "Reload your model without any quantization by setting `load_in_4bit = False`."
        )

    return loftq_config


def fast_inference_setup(model_name, model_config):
    """
    Decide whether vLLM fast inference can be used and normalize the model
    name. Returns (fast_inference, model_name).
    """
    has_vllm = is_vLLM_available()
    if not has_vllm:
        logger.warning_once(
            "Unsloth: vLLM is not installed! Will use Unsloth inference!"
        )

    from unsloth_zoo.vllm_utils import (
        patch_vllm,
        vllm_dynamic_quant_supported,
    )

    patch_vllm()

    # Fall back from the dynamic quant to the plain bnb-4bit variant when
    # fast inference does not support the dynamic quant for this model.
    dynamic_suffix = "unsloth-bnb-4bit"
    if model_name.endswith(dynamic_suffix) and not vllm_dynamic_quant_supported(
        model_name, model_config
    ):
        logger.warning_once(
            f"Unsloth: Switching from Unsloth dynamic quant to normal quant since\n"
            f"we do not yet support fast inference for {model_name}"
        )
        model_name = model_name[: -len(dynamic_suffix)] + "bnb-4bit"
    return has_vllm, model_name


def patch_peft_fast_inference(model):
    """
    Mirror the inner model's vLLM engine and fast-generate helpers onto the
    PEFT wrapper, plus LoRA save/load helpers. No-op without a vLLM engine.
    """
    inner = model.model
    if getattr(inner, "vllm_engine", None) is None:
        return

    model.vllm_engine = inner.vllm_engine
    model.fast_generate = inner.fast_generate
    model.fast_generate_batches = inner.fast_generate_batches

    # Also saving and loading LoRA
    from unsloth_zoo.vllm_utils import save_lora, load_lora

    model.save_lora = functools.partial(save_lora, model)
    model.load_lora = functools.partial(load_lora, model)


def error_out_no_vllm(*args, **kwargs):
    """Stand-in for vLLM-only entry points on models without vLLM support."""
    message = "Unsloth: vLLM is not yet supported for fast inference for this model! Please use `.generate` instead"
    raise NotImplementedError(message)


# Optional TorchAO availability probe: leave `AOBaseConfig` and
# `Int4WeightOnlyConfig` as None sentinels when torchao is missing (or has
# moved/renamed the class), so later code can feature-detect instead of
# failing at import time.
try:
    from torchao.core.config import AOBaseConfig

    try:
        from torchao.quantization import Int4WeightOnlyConfig
    except:
        # torchao is installed but relocated the config class in this version.
        print("Unsloth: TorchAO changed `torchao.quantization.Int4WeightOnlyConfig`")
        Int4WeightOnlyConfig = None
except:
    # torchao is not installed at all.
    AOBaseConfig = None
    Int4WeightOnlyConfig = None


@dataclass
class TorchAOConfig:
    """Bundle of TorchAO quantization rules used for QAT preparation.

    NOTE(review): the default `base_config_and_filter_fns` factory calls
    `Int4WeightOnlyConfig`, which is None when torchao is unavailable or has
    moved the class — constructing a default instance would then raise;
    confirm torchao is importable before relying on the default.
    """

    # Name of the QAT preset this config corresponds to (e.g. "int4").
    qat_scheme: Optional[str] = "int4"

    # Each (config, filter_fn) pair defines a quantization rule
    base_config_and_filter_fns: List[
        Tuple["AOBaseConfig", Optional[Callable[[torch.nn.Module, str], bool]]]
    ] = field(
        default_factory = lambda: [
            (
                Int4WeightOnlyConfig(group_size = 128),
                lambda m, _: isinstance(m, torch.nn.Linear)
                and getattr(m, "in_features", 0) >= 128,
            ),
        ]
    )

    # Optional transformation to apply before quantization setup
    prequantization_transform: Optional[Callable[[torch.nn.Module], None]] = None


def _untie_input_output_embeddings(model: torch.nn.Module) -> None:
    """
    Utility to untie input/output embeddings in a HuggingFace model.
    This is useful if we want to quantize the input/ouput embeddings differently.
    Model is modified in-place.
    """

    # 1) Persist setting in config
    if hasattr(model.config, "tie_word_embeddings"):
        model.config.tie_word_embeddings = False

    # 2) Find input and output embeddings
    in_emb = model.get_input_embeddings()
    out_proj = model.get_output_embeddings() or getattr(model, "lm_head", None)
    if out_proj is None:
        raise AttributeError("Couldn't locate output projection (lm_head).")

    # (Optional) sanity: shapes should match [vocab, hidden]
    assert (
        out_proj.weight.shape == in_emb.weight.shape
    ), f"Shape mismatch: out_proj {out_proj.weight.shape} vs in_emb {in_emb.weight.shape}"

    # 3) Only clone if they are actually tied (shared storage)
    if out_proj.weight.data_ptr() == in_emb.weight.data_ptr():
        with torch.no_grad():
            W = in_emb.weight.detach().clone()
        out_proj.weight = torch.nn.Parameter(W)  # new storage, keeps dtype/device

    # 4) Prevent future automatic re-tying
    def _no_tie(self):
        return

    model.tie_weights = _no_tie.__get__(model, model.__class__)

    # 5) Verify no shared storage
    assert (
        out_proj.weight.data_ptr() != in_emb.weight.data_ptr()
    ), "Embeddings still tied!"


def _filter_fn_to_fqns(
    model: torch.nn.Module,
    filter_fn: Callable[[torch.nn.Module, str], bool],
) -> Iterator[str]:
    """
    Given a model and a filter function (m, fqn) -> bool,
    yield fully qualified names (FQNs) of modules that match.
    """
    for fqn, module in model.named_modules():
        if filter_fn(module, fqn):
            yield fqn


def _convert_torchao_model(model):
    """
    Convert a QAT-prepared model to real quantized weights and attach the
    resulting HF `TorchAoConfig` to `model.config.quantization_config`.
    """
    from transformers import TorchAoConfig
    from torchao.quantization import quantize_, ModuleFqnToConfig
    from torchao.quantization.qat import QATConfig
    from torchao.utils import TorchAOBaseTensor

    fqn_to_config = {}
    for base_config, filter_fn in model._torchao_config.base_config_and_filter_fns:
        quantize_(model, QATConfig(base_config, step = "convert"), filter_fn = filter_fn)

        if filter_fn is None:
            # A None filter means quantize_'s default rule; only one allowed.
            if "_default" in fqn_to_config:
                raise ValueError("Cannot use multiple default quantization configs")
            fqn_to_config["_default"] = base_config
        else:
            # Record the exact modules each rule matched.
            for fqn in _filter_fn_to_fqns(model, filter_fn):
                if fqn in fqn_to_config:
                    raise ValueError(f"Found multiple quantization configs for {fqn}")
                fqn_to_config[fqn] = base_config

    # If either embedding matrix ended up quantized, tell HF to include them.
    input_embeddings = model.get_input_embeddings()
    output_projection = model.get_output_embeddings() or getattr(model, "lm_head", None)
    extra_kwargs = {}
    embeddings_quantized = isinstance(input_embeddings.weight, TorchAOBaseTensor) or (
        output_projection is not None
        and isinstance(output_projection.weight, TorchAOBaseTensor)
    )
    if embeddings_quantized:
        extra_kwargs["include_input_output_embeddings"] = True
        extra_kwargs["modules_to_not_convert"] = []

    model.config.quantization_config = TorchAoConfig(
        quant_type = ModuleFqnToConfig(fqn_to_config), **extra_kwargs
    )


def _prepare_model_for_qat(
    model: torch.nn.Module, qat_scheme: Union[str, TorchAOConfig]
) -> torch.nn.Module:
    """
    Transform a model for Quantization-Aware Training (QAT) during fine-tuning.

    On a high level, this means fake quantizing the base (frozen) model during training.
    Fake quantization refers to simulating quantization numerics in high precision (e.g. bf16).
    This helps mitigate quantization degradations when the model is quantized after training.

    QAT can be optionally combined with LoRA fine-tuning for additional throughput improvement.
    For more details: https://dev-discuss.pytorch.org/t/speeding-up-qat-by-1-89x-with-lora/2700

    Args:
        model: model to prepare; modified in-place and also returned.
        qat_scheme: preset name ("fp8-int4", "fp8-fp8", "int8-int4", "int4",
            "int8") or an explicit `TorchAOConfig`.

    Raises:
        ImportError: if torchao (or a required config class) is unavailable.
        ValueError: for an unrecognized preset name.
    """
    try:
        from torchao.quantization import PerRow, quantize_
        from torchao.quantization.granularity import PerGroup, PerAxis
        from torchao.quantization.qat import QATConfig
    except ImportError:
        raise ImportError(TORCHAO_MSG)

    # Gemma3 models have issues with int8 embedding quantization due to their
    # large vocabulary size (262144). Auto-switch to int4 weight-only instead.
    if qat_scheme == "int8-int4":
        model_types = get_transformers_model_type(model.config)
        is_gemma3 = any("gemma3" in mt or "gemma_3" in mt for mt in model_types)
        if is_gemma3:
            print(
                "Unsloth: Gemma3 has a large vocabulary causing int8 embedding issues. "
                "Switching to int4 weight-only QAT for training stability."
            )
            qat_scheme = "int4"

    # Map a preset name to a TorchAOConfig; an explicit config passes through.
    if not isinstance(qat_scheme, TorchAOConfig):
        torchao_config: Optional[TorchAOConfig] = None
        # fp8 dynamic activations + int4 weights, linear layers >= group_size.
        if qat_scheme == "fp8-int4":
            try:
                from torchao.quantization import Float8DynamicActivationInt4WeightConfig
            except ImportError:
                raise ImportError(TORCHAO_MSG)
            group_size = 128
            base_config = Float8DynamicActivationInt4WeightConfig()
            filter_fn = (
                lambda m, _: isinstance(m, torch.nn.Linear)
                and m.in_features >= group_size
            )
            torchao_config = TorchAOConfig(
                qat_scheme = qat_scheme,
                base_config_and_filter_fns = [(base_config, filter_fn)],
            )
        # fp8 dynamic activations + fp8 weights, per-row granularity.
        elif qat_scheme == "fp8-fp8":
            try:
                from torchao.quantization import (
                    Float8DynamicActivationFloat8WeightConfig,
                )
            except ImportError:
                raise ImportError(TORCHAO_MSG)
            base_config = Float8DynamicActivationFloat8WeightConfig(
                granularity = PerRow()
            )
            torchao_config = TorchAOConfig(
                qat_scheme = qat_scheme, base_config_and_filter_fns = [(base_config, None)]
            )
        # int8 embeddings + int8-activation/int4-weight everything else;
        # requires untying the embeddings first (see prequantization_transform).
        elif qat_scheme == "int8-int4":
            try:
                from torchao.quantization import (
                    Int8DynamicActivationIntxWeightConfig,
                    IntxWeightOnlyConfig,
                )
            except ImportError:
                raise ImportError(TORCHAO_MSG)
            torchao_config = TorchAOConfig(
                qat_scheme = qat_scheme,
                base_config_and_filter_fns = [
                    (
                        IntxWeightOnlyConfig(
                            weight_dtype = torch.int8, granularity = PerAxis(0)
                        ),
                        lambda m, fqn: isinstance(m, torch.nn.Embedding),
                    ),
                    (
                        Int8DynamicActivationIntxWeightConfig(
                            weight_dtype = torch.int4, weight_granularity = PerGroup(32)
                        ),
                        None,
                    ),
                ],
                prequantization_transform = _untie_input_output_embeddings,
            )
        # int4 weight-only, linear layers >= group_size.
        elif qat_scheme == "int4":
            try:
                from torchao.quantization import Int4WeightOnlyConfig
            except ImportError:
                raise ImportError(TORCHAO_MSG)
            group_size = 128
            base_config = Int4WeightOnlyConfig(group_size = group_size)
            filter_fn = (
                lambda m, _: isinstance(m, torch.nn.Linear)
                and m.in_features >= group_size
            )
            torchao_config = TorchAOConfig(
                qat_scheme = qat_scheme,
                base_config_and_filter_fns = [(base_config, filter_fn)],
            )
        # int8 weight-only on all linear layers, per-axis granularity.
        elif qat_scheme == "int8":
            try:
                from torchao.quantization import IntxWeightOnlyConfig
                from torchao.quantization.granularity import PerAxis
            except ImportError:
                raise ImportError(TORCHAO_MSG)

            base_config = IntxWeightOnlyConfig(
                weight_dtype = torch.int8,
                granularity = PerAxis(0),
            )
            filter_fn = lambda m, _: isinstance(m, torch.nn.Linear)
            torchao_config = TorchAOConfig(
                qat_scheme = qat_scheme,
                base_config_and_filter_fns = [(base_config, filter_fn)],
            )
        else:
            raise ValueError(f"Unexpected QAT scheme {qat_scheme}")
        assert torchao_config is not None, f"TorchAOConfig was not set for {qat_scheme}"
    else:
        torchao_config = qat_scheme

    # Save Torchao metadata everywhere
    # Attach the config at every `.model` nesting level so wrappers see it too.
    inner_model = model
    while hasattr(inner_model, "model"):
        inner_model._torchao_config = torchao_config
        inner_model = inner_model.model
    inner_model._torchao_config = torchao_config

    # Optionally transform the model (e.g. untie embeddings), then install
    # the fake-quantization hooks for each rule.
    if torchao_config.prequantization_transform is not None:
        torchao_config.prequantization_transform(model)
    for base_config, filter_fn in torchao_config.base_config_and_filter_fns:
        quantize_(model, QATConfig(base_config, step = "prepare"), filter_fn = filter_fn)

    return model


def patch_hf_quantizer():
    """Mark HF's FP8 quantizers as trainable so Trainer accepts such models."""

    def make_trainable(self):
        # HF checks these properties before allowing training on quantized models.
        return True

    quantizer_targets = (
        ("transformers.quantizers.quantizer_finegrained_fp8", "FineGrainedFP8HfQuantizer"),
        ("transformers.quantizers.quantizer_fbgemm_fp8", "FbgemmFp8HfQuantizer"),
    )
    for module_path, class_name in quantizer_targets:
        try:
            quantizer_class = getattr(
                __import__(module_path, fromlist = [class_name]), class_name
            )
            quantizer_class.is_trainable = property(make_trainable)
            quantizer_class.is_qat_trainable = property(make_trainable)
        except Exception as e:
            logger.warning(f"Failed to patch {class_name}. Error {e}")


# Apply the quantizer trainability patches once, at module import time.
patch_hf_quantizer()


def verify_fp8_support_if_applicable(model_config):
    """
    Raise ValueError when the config requests FP8 quantization on hardware
    that cannot run it. No-op for non-FP8 quantization methods.
    """
    quant_method = get_quant_type(model_config)
    if quant_method in ("fbgemm_fp8", "fp8") and DEVICE_TYPE != "cuda":
        raise ValueError(
            f"Unsloth: FP8 quantization is only supported on CUDA GPUs. You are using {DEVICE_TYPE}."
        )

    # [TODO] Need to add FP8 support for Intel XPUs
    if DEVICE_TYPE != "cuda":
        return
    major_version, minor_version = torch.cuda.get_device_capability()
    if quant_method == "fbgemm_fp8" and major_version < 9:
        # While L4 does support FP8 as data type, it doesn't have fbgemm (package) support yet. So we restrict it.
        raise ValueError(
            f"Unsloth: FBGEMM FP8 quantization is only supported on H100 and higher GPUs. L4 is not supported. You are using {torch.cuda.get_device_name()}. Refer to https://developer.nvidia.com/cuda-gpus for more details."
        )
    if quant_method == "fp8" and major_version * 10 + minor_version < 89:
        # In case of block quantized, we allow L4 because we fall back to torchao kernels.
        raise ValueError(
            f"Unsloth: FP8 quantization is only supported on L4 and higher GPUs with compute capability 8.9 or higher. You are using {torch.cuda.get_device_name()}. Refer to https://developer.nvidia.com/cuda-gpus for more details."
        )


def _get_inference_mode_context_manager(model: torch.nn.Module):
    """
    If the state dict was quantized using torchao, we will run into
    the following error when calling ops like aten.t() in inference mode.
    This is a bug in PyTorch that affects all tensor subclasses.

        Cannot set version_counter for inference tensor

    For now, we work around this issue by using `torch.no_grad()` in this case.
    See https://github.com/pytorch/pytorch/issues/164872 for more details.
    Otherwise, just return `torch.inference_mode()`.
    """
    torchao_config = getattr(model, "torchao_config", None)
    if torchao_config is not None and torchao_config.qat_scheme is None:
        return torch.no_grad()
    else:
        return torch.inference_mode()


def hf_login(token: Optional[str] = None) -> Optional[str]:
    """
    Best-effort HuggingFace Hub login.

    Resolves a cached token when none is given, attempts `login`, and
    returns the token (or None when no token could be found).
    """
    if token is None:
        try:
            from huggingface_hub import get_token

            token = get_token()
        except:
            return None
        if token is None:
            return None

    try:
        from huggingface_hub import login

        login(token = token)
        return token
    except Exception as e:
        logger.info(f"Failed to login to huggingface using token with error: {e}")
    return token


def make_fast_generate_wrapper(original_generate):
    """
    Wrap `model.generate` so that vLLM-only call patterns raise a clear
    error when `fast_inference=False`.
    """

    @functools.wraps(original_generate)
    def _fast_generate_wrapper(*args, **kwargs):
        # vLLM-only keyword arguments are rejected outright.
        if "sampling_params" in kwargs:
            raise ValueError(
                "Unsloth: `sampling_params` is only supported when `fast_inference=True` (vLLM). "
                "Since `fast_inference=False`, use HuggingFace generate arguments instead:\n"
                "  model.fast_generate(**tokens.to('cuda'), max_new_tokens=64, temperature=1.0, top_p=0.95)"
            )

        if "lora_request" in kwargs:
            raise ValueError(
                "Unsloth: `lora_request` is only supported when `fast_inference=True` (vLLM). "
                "Since `fast_inference=False`, LoRA weights are already merged into the model."
            )

        # Raw text (a string, or a non-empty list/tuple starting with a string)
        # is a vLLM-style call; HF generate needs tokenized tensors.
        if args:
            candidate = args[0]
            looks_like_text = isinstance(candidate, str) or (
                isinstance(candidate, (list, tuple))
                and len(candidate) > 0
                and isinstance(candidate[0], str)
            )
            if looks_like_text:
                raise ValueError(
                    "Unsloth: Passing text strings to `fast_generate` is only supported "
                    "when `fast_inference=True` (vLLM). Since `fast_inference=False`, you must "
                    "tokenize the input first:\n\n"
                    "  messages = tokenizer.apply_chat_template(\n"
                    '      [{"role": "user", "content": "Your prompt here"}],\n'
                    "      tokenize=True, add_generation_prompt=True,\n"
                    '      return_tensors="pt", return_dict=True\n'
                    "  )\n"
                    "  output = model.fast_generate(\n"
                    "      **messages.to('cuda'),\n"
                    "      max_new_tokens=64,\n"
                    "      temperature=1.0,\n"
                    "  )"
                )

        # Everything else passes straight through to the real generate.
        return original_generate(*args, **kwargs)

    return _fast_generate_wrapper
