# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ._utils import (
    _prepare_model_for_qat,
    is_bfloat16_supported,
    is_vLLM_available,
    HAS_FLASH_ATTENTION,
    HAS_FLASH_ATTENTION_SOFTCAPPING,
    USE_MODELSCOPE,
    get_transformers_model_type,
    hf_login,
)
from .granite import FastGraniteModel
from .llama import FastLlamaModel, logger
from .mistral import FastMistralModel
from .qwen2 import FastQwen2Model
from .qwen3 import FastQwen3Model
from .qwen3_moe import FastQwen3MoeModel
from .cohere import FastCohereModel
from transformers import AutoConfig
from transformers import __version__ as transformers_version
from peft import PeftConfig, PeftModel
from .loader_utils import (
    _get_fp8_mode_and_check_settings,
    _offline_quantize_to_fp8,
    _tag_model_with_fp8_torchao_config,
    get_model_name,
)
import os, contextlib, sys

try:
    from huggingface_hub import get_token
except ImportError:
    try:
        from huggingface_hub.utils import get_token
    except ImportError:
        # For older versions of huggingface_hub
        from huggingface_hub.utils._token import get_token
from huggingface_hub import HfFileSystem
import importlib.util
from ..device_type import (
    is_hip,
    get_device_type,
    DEVICE_TYPE,
    DEVICE_TYPE_TORCH,
    DEVICE_COUNT,
    ALLOW_PREQUANTIZED_MODELS,
    ALLOW_BITSANDBYTES,
)

# https://github.com/huggingface/transformers/pull/26037 allows 4 bit loading!
from unsloth_zoo.utils import Version, _get_dtype
from unsloth_zoo.hf_utils import dtype_from_config
from unsloth_zoo.tiled_mlp import patch_tiled_mlp

transformers_version = Version(transformers_version)
SUPPORTS_FOURBIT = transformers_version >= Version("4.37")
SUPPORTS_GEMMA = transformers_version >= Version("4.38")
SUPPORTS_GEMMA2 = transformers_version >= Version("4.42")
SUPPORTS_LLAMA31 = transformers_version >= Version("4.43.2")
SUPPORTS_LLAMA32 = transformers_version > Version("4.45.0")
SUPPORTS_GRANITE = transformers_version >= Version("4.46.0")
SUPPORTS_QWEN3 = transformers_version >= Version("4.50.3")
SUPPORTS_QWEN3_MOE = transformers_version >= Version("4.50.3")
SUPPORTS_FALCON_H1 = transformers_version >= Version("4.53.0")
SUPPORTS_GEMMA3N = transformers_version >= Version("4.53.0")
SUPPORTS_GPTOSS = transformers_version >= Version("4.55.0")
if SUPPORTS_GEMMA:
    from .gemma import FastGemmaModel
if SUPPORTS_GEMMA2:
    from .gemma2 import FastGemma2Model
if SUPPORTS_FALCON_H1:
    from .falcon_h1 import FastFalconH1Model
import torch
from ._utils import (
    patch_compiling_bitsandbytes,
    patch_model_and_tokenizer,
    prepare_model_for_kbit_training,
    patch_unsloth_smart_gradient_checkpointing,
    patch_compiled_autograd,
    process_vision_info,
    unsloth_compile_transformers,
    fast_inference_setup,
)

global FORCE_FLOAT32
# Forces float32 precision since float16 overflows to infinity
FORCE_FLOAT32 = [
    "gemma3,",  # Trailing comma so "gemma3" does not also match "gemma3n"
    "gemma3n",
    "gpt_oss",
]

global DISABLE_COMPILE_MODEL_NAMES
# Comma-separated model types within each entry must be alphabetically sorted
DISABLE_COMPILE_MODEL_NAMES = [
    "aya_vision",
    "modernbert",
    "granite,llava_next",  # Granite-vision 3
]

global DISABLE_SDPA_MODEL_NAMES
# Disable SDPA for these model types since their SDPA path is incorrect
DISABLE_SDPA_MODEL_NAMES = [
    "gemma3,",  # Trailing comma so "gemma3" does not also match "gemma3n"
]


class FastLanguageModel(FastLlamaModel):
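    """Loader entry point for language models. Dispatches to an architecture-specific
    Fast*Model (Llama, Mistral, Gemma, Qwen, ...) and falls back to FastModel otherwise."""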
    @staticmethod
    def from_pretrained(
        model_name = "unsloth/Llama-3.2-1B-Instruct",
        max_seq_length = 2048,
        dtype = None,
        load_in_4bit = True,  # 4bit QLoRA
        load_in_8bit = False,  # 8bit  LoRA
        load_in_16bit = False,  # 16bit LoRA
        full_finetuning = False,
        token = None,
        device_map = "sequential",
        rope_scaling = None,
        fix_tokenizer = True,
        trust_remote_code = False,
        use_gradient_checkpointing = "unsloth",
        resize_model_vocab = None,
        revision = None,
        use_exact_model_name = False,
        offload_embedding = False,
        float32_mixed_precision = None,  # Forces float32 mixed precision
        fast_inference = False,  # uses vLLM
        gpu_memory_utilization = 0.5,
        float8_kv_cache = False,
        random_state = 3407,
        max_lora_rank = 64,
        disable_log_stats = True,
        qat_scheme = None,
        load_in_fp8 = False,  # fp8 LoRA (True, False, 'block')
        unsloth_tiled_mlp = False,
        *args,
        **kwargs,
    ):
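        """
        Load a language model and its tokenizer with Unsloth's optimizations applied.

        Dispatches to an architecture-specific Fast*Model; requests for 8bit loading,
        full finetuning or QAT, as well as unsupported architectures, are forwarded to
        `FastModel.from_pretrained`. Returns `(model, tokenizer)`.

        Example (minimal sketch using the defaults above):
            model, tokenizer = FastLanguageModel.from_pretrained(
                model_name = "unsloth/Llama-3.2-1B-Instruct",
                max_seq_length = 2048,
                load_in_4bit = True,
            )
        """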
        # Respect user-provided quantization_config (e.g. BitsAndBytesConfig)
        quantization_config = kwargs.get("quantization_config", None)
        if quantization_config is not None:
            if isinstance(quantization_config, dict):
                q_load_in_4bit = quantization_config.get("load_in_4bit", False)
                q_load_in_8bit = quantization_config.get("load_in_8bit", False)
            else:
                q_load_in_4bit = getattr(quantization_config, "load_in_4bit", False)
                q_load_in_8bit = getattr(quantization_config, "load_in_8bit", False)
            if q_load_in_4bit:
                load_in_4bit = True
                load_in_8bit = False
            if q_load_in_8bit:
                load_in_8bit = True
                load_in_4bit = False

        # Login to allow private models
        token = hf_login(token)
        # Align dtype with bnb_4bit_compute_dtype if provided and dtype is unset.
        if dtype is None and quantization_config is not None:
            bnb_compute_dtype = None
            if isinstance(quantization_config, dict):
                if quantization_config.get("load_in_4bit", False):
                    bnb_compute_dtype = quantization_config.get(
                        "bnb_4bit_compute_dtype", None
                    )
            else:
                if getattr(quantization_config, "load_in_4bit", False):
                    bnb_compute_dtype = getattr(
                        quantization_config, "bnb_4bit_compute_dtype", None
                    )
            if isinstance(bnb_compute_dtype, str):
                bnb_compute_dtype = getattr(torch, bnb_compute_dtype, None)
            if isinstance(bnb_compute_dtype, torch.dtype):
                dtype = bnb_compute_dtype
        if load_in_8bit or full_finetuning or qat_scheme is not None:
            return FastModel.from_pretrained(
                model_name = model_name,
                max_seq_length = max_seq_length,
                dtype = dtype,
                load_in_4bit = load_in_4bit,
                load_in_8bit = load_in_8bit,
                load_in_16bit = load_in_16bit,
                full_finetuning = full_finetuning,
                token = token,
                device_map = device_map,
                rope_scaling = rope_scaling,  # [TODO] No effect
                fix_tokenizer = fix_tokenizer,  # [TODO] No effect
                trust_remote_code = trust_remote_code,
                use_gradient_checkpointing = use_gradient_checkpointing,
                resize_model_vocab = resize_model_vocab,  # [TODO] No effect
                revision = revision,
                return_logits = False,  # Do not return logits
                fullgraph = True,  # No graph breaks
                use_exact_model_name = use_exact_model_name,
                offload_embedding = offload_embedding,
                float32_mixed_precision = float32_mixed_precision,
                # Pass vLLM/inference parameters
                fast_inference = fast_inference,
                gpu_memory_utilization = gpu_memory_utilization,
                float8_kv_cache = float8_kv_cache,
                random_state = random_state,
                max_lora_rank = max_lora_rank,
                disable_log_stats = disable_log_stats,
                qat_scheme = qat_scheme,
                load_in_fp8 = load_in_fp8,
                unsloth_tiled_mlp = unsloth_tiled_mlp,
                *args,
                **kwargs,
            )

        if isinstance(dtype, str) and dtype in ["float16", "bfloat16", "float32"]:
            dtype = getattr(torch, dtype)
        assert (
            dtype is None
            or dtype == torch.float16
            or dtype == torch.bfloat16
            or dtype == torch.float32
        )

        if fast_inference:
            if importlib.util.find_spec("vllm") is None:
                raise ImportError(
                    "Unsloth: Please install vLLM before enabling `fast_inference`!\n"
                    "You can do this in a terminal via `pip install vllm`"
                )
            if DEVICE_TYPE_TORCH == "cuda":
                for i in range(DEVICE_COUNT):
                    # [TODO] DGX Spark vLLM breaks
                    if "NVIDIA GB10" in str(torch.cuda.get_device_name(i)).upper():
                        print(
                            "Unsloth: DGX Spark detected - `fast_inference=True` is currently broken as of January 2026.\n"
                            "Defaulting to native Unsloth inference."
                        )
                        fast_inference = False
                        break

        # [TODO] For now load_in_fp8 only works with fast_inference, i.e. vLLM
        if load_in_fp8 != False:
            if not fast_inference:
                raise NotImplementedError(
                    "Unsloth: set `fast_inference = True` when doing `load_in_fp8`."
                )
        # Check if 4bit is allowed specifically for AMD
        if not ALLOW_BITSANDBYTES and not use_exact_model_name:
            if load_in_4bit or load_in_8bit or model_name.lower().endswith("-bnb-4bit"):
                print(
                    "Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now."
                )
            load_in_4bit = False

        # Find FP8, BnB 4bit, other mapped names
        old_model_name = model_name
        fp8_mode = None
        if not use_exact_model_name:
            new_model_name = get_model_name(
                model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8
            )
            if new_model_name is None and load_in_fp8 != False:
                fp8_mode = _get_fp8_mode_and_check_settings(
                    load_in_fp8,
                    fast_inference,
                    full_finetuning,
                    load_in_4bit,
                    load_in_8bit,
                    load_in_16bit,
                    use_exact_model_name,
                )
                model_name = _offline_quantize_to_fp8(model_name, fp8_mode)
            else:
                assert new_model_name is not None
                model_name = new_model_name

        # Check if pre-quantized models are allowed
        # E.g. AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64
        if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
            ("-unsloth-bnb-4bit", "-bnb-4bit")
        ):
            # Strip the quantization suffix while preserving the original casing
            if model_name.lower().endswith("-unsloth-bnb-4bit"):
                model_name = model_name[: -len("-unsloth-bnb-4bit")]
            else:
                model_name = model_name[: -len("-bnb-4bit")]
        # For -BF16 model names, disable 4bit / 8bit / fp8 loading and use 16bit
        if model_name.lower().endswith("-bf16"):
            load_in_4bit = False
            load_in_8bit = False
            load_in_fp8 = False
            load_in_16bit = True

        if USE_MODELSCOPE and not os.path.exists(model_name):
            from modelscope import snapshot_download

            model_name = snapshot_download(model_name)

        # First check if it's a normal model via AutoConfig
        from huggingface_hub.utils import (
            disable_progress_bars,
            enable_progress_bars,
            are_progress_bars_disabled,
        )

        was_disabled = are_progress_bars_disabled()
        disable_progress_bars()

        autoconfig_error = None
        peft_error = None
        model_config = None
        peft_config = None
        try:
            model_config = AutoConfig.from_pretrained(
                model_name,
                token = token,
                revision = revision,
                trust_remote_code = trust_remote_code,
            )
            is_model = True
        except ImportError:
            raise
        except Exception as error:
            autoconfig_error = str(error)
            if "architecture" in autoconfig_error:
                raise ValueError(
                    f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"
                    f"Please update transformers via `pip install --upgrade transformers` and try again."
                )
            is_model = False
        try:
            peft_config = PeftConfig.from_pretrained(
                model_name,
                token = token,
                revision = revision,
                trust_remote_code = trust_remote_code,
            )
            is_peft = True
        except ImportError:
            raise
        except Exception as error:
            peft_error = str(error)
            if "architecture" in peft_error:
                raise ValueError(
                    f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"
                    f"Please update transformers via `pip install --upgrade transformers` and try again."
                )
            is_peft = False

        # Old transformers versions check
        both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32

        # Error out if both LoRA and normal model config exists.
        if both_exist:
            raise RuntimeError(
                "Unsloth: Your repo has a LoRA adapter and a base model.\n"
                "You have 2 files `config.json` and `adapter_config.json`.\n"
                "We must only allow one config file.\n"
                "Please separate the LoRA and base models to 2 repos."
            )
        model_types = get_transformers_model_type(
            peft_config if peft_config is not None else model_config,
            trust_remote_code = trust_remote_code,
        )
        if len(model_types) == 1:
            model_type = model_types[0]
        else:
            # Keep all model types if there is more than one architecture
            model_type = model_types

        # Newer transformers versions need a manual check.
        if SUPPORTS_LLAMA32:
            # Check if folder exists locally
            if os.path.isdir(model_name):
                exist_adapter_config = os.path.exists(
                    os.path.join(model_name, "adapter_config.json")
                )
                exist_config = os.path.exists(os.path.join(model_name, "config.json"))
                both_exist = exist_adapter_config and exist_config
            else:
                # Because HfFileSystem assumes linux paths, we need to set the path with forward slashes, even on Windows.
                files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
                files = list(os.path.split(x)[-1] for x in files)
                if (
                    sum(x == "adapter_config.json" or x == "config.json" for x in files)
                    >= 2
                ):
                    both_exist = True

        if not is_model and not is_peft:
            error = autoconfig_error if autoconfig_error is not None else peft_error
            # Old transformers version
            if "rope_scaling" in error.lower() and not SUPPORTS_LLAMA31:
                raise ImportError(
                    f"Unsloth: Your transformers version of {transformers_version} does not support new RoPE scaling methods.\n"
                    f"This includes Llama 3.1. The minimum required version is 4.43.2\n"
                    f'Try `pip install --upgrade "transformers>=4.43.2"`\n'
                    f"to obtain the latest transformers build, then restart this session."
                )
            # Create a combined error message showing both failures
            combined_error = (
                "Unsloth: Failed to load model. Both AutoConfig and PeftConfig loading failed.\n\n"
                f"AutoConfig error: {autoconfig_error}\n\n"
                f"PeftConfig error: {peft_error}\n\n"
            )
            raise RuntimeError(combined_error)

        # Get base model for PEFT:
        if is_peft:
            # Check base model again for PEFT
            model_name = peft_config.base_model_name_or_path
            if not use_exact_model_name:
                model_name = get_model_name(model_name, load_in_4bit)
            # Check if pre-quantized models are allowed
            # E.g. AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64
            if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
                ("-unsloth-bnb-4bit", "-bnb-4bit")
            ):
                # Strip the quantization suffix while preserving the original casing
                if model_name.lower().endswith("-unsloth-bnb-4bit"):
                    model_name = model_name[: -len("-unsloth-bnb-4bit")]
                else:
                    model_name = model_name[: -len("-bnb-4bit")]
            # For -BF16 model names, disable 4bit / 8bit / fp8 loading and use 16bit
            if model_name.lower().endswith("-bf16"):
                load_in_4bit = False
                load_in_8bit = False
                load_in_fp8 = False
                load_in_16bit = True

            model_config = AutoConfig.from_pretrained(
                model_name,
                token = token,
                trust_remote_code = trust_remote_code,
            )

        if not was_disabled:
            enable_progress_bars()

        if model_type == "llama":
            scaling_type = None
            if getattr(model_config, "rope_scaling", None) is not None:
                scaling_type1 = model_config.rope_scaling.get("type", None)
                scaling_type2 = model_config.rope_scaling.get("rope_type", None)
                scaling_type = (
                    scaling_type1 if scaling_type1 is not None else scaling_type2
                )

            if scaling_type == "llama3" and not SUPPORTS_LLAMA31:
                raise ImportError(
                    f"Unsloth: Your transformers version of {transformers_version} does not support Llama 3.1.\n"
                    f"The minimum required version is 4.43.2\n"
                    f'Try `pip install --upgrade "transformers>=4.43.2"`\n'
                    f"to obtain the latest transformers build, then restart this session."
                )

            dispatch_model = FastLlamaModel

        elif model_type == "mistral":
            dispatch_model = FastMistralModel
        elif model_type == "gemma":
            if not SUPPORTS_GEMMA:
                raise ImportError(
                    f"Unsloth: Your transformers version of {transformers_version} does not support Gemma.\n"
                    f"The minimum required version is 4.38.\n"
                    f'Try `pip install --upgrade "transformers>=4.38"`\n'
                    f"to obtain the latest transformers build, then restart this session."
                )
            dispatch_model = FastGemmaModel
        elif model_type == "gemma2":
            if not SUPPORTS_GEMMA2:
                raise ImportError(
                    f"Unsloth: Your transformers version of {transformers_version} does not support Gemma2.\n"
                    f"The minimum required version is 4.42.3.\n"
                    f'Try `pip install --upgrade "transformers>=4.42.3"`\n'
                    f"to obtain the latest transformers build, then restart this session."
                )
            # Also check for softcapping support in flash-attn which is faster!
            if is_bfloat16_supported() and not HAS_FLASH_ATTENTION:
                print(
                    "Unsloth: If you want to finetune Gemma 2, install flash-attn to make it faster!\n"
                    "To install flash-attn, do the below:\n"
                    '\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
                )
            elif HAS_FLASH_ATTENTION and not HAS_FLASH_ATTENTION_SOFTCAPPING:
                print(
                    "Unsloth: If you want to finetune Gemma 2, upgrade flash-attn to version 2.6.3 or higher!\n"
                    "Newer versions support faster and less memory usage kernels for Gemma 2's attention softcapping!\n"
                    "To update flash-attn, do the below:\n"
                    '\npip install --no-deps --upgrade "flash-attn>=2.6.3"'
                )

            dispatch_model = FastGemma2Model
        elif model_type == "qwen2":
            dispatch_model = FastQwen2Model
        elif model_type == "qwen3":  # or model_type == "qwen3_moe":
            if not SUPPORTS_QWEN3 or not SUPPORTS_QWEN3_MOE:
                raise ImportError(
                    f"Unsloth: Your transformers version of {transformers_version} does not support Qwen3.\n"
                    f"The minimum required version is 4.50.3.\n"
                    f'Try `pip install --upgrade "transformers>=4.50.3"`\n'
                    f"to obtain the latest transformers build, then restart this session."
                )
            dispatch_model = (
                FastQwen3Model if model_type == "qwen3" else FastQwen3MoeModel
            )
        # elif model_type == "falcon_h1":
        #     dispatch_model = FastFalconH1Model
        #     if not SUPPORTS_FALCON_H1:
        #         raise ImportError(
        #             f"Unsloth: Your transformers version of {transformers_version} does not support FalconH1.\n"\
        #             f"The minimum required version is 4.50.3.\n"\
        #             f'Try `pip install --upgrade "transformers>=4.50.3"`\n'\
        #             f"to obtain the latest transformers build, then restart this session."\
        #         )
        # Temporarily disable optimized Cohere until errors match
        # elif model_type == "cohere":
        #     dispatch_model = FastCohereModel
        # Temporarily disable optimized Granite until errors match
        # elif model_type == "granite":
        #     dispatch_model = FastGraniteModel
        else:
            return FastModel.from_pretrained(
                model_name = old_model_name,
                max_seq_length = max_seq_length,
                dtype = dtype,
                load_in_4bit = load_in_4bit,
                load_in_8bit = load_in_8bit,
                load_in_16bit = load_in_16bit,
                full_finetuning = full_finetuning,
                token = token,
                device_map = device_map,
                rope_scaling = rope_scaling,  # [TODO] No effect
                fix_tokenizer = fix_tokenizer,  # [TODO] No effect
                trust_remote_code = trust_remote_code,
                use_gradient_checkpointing = use_gradient_checkpointing,
                resize_model_vocab = resize_model_vocab,  # [TODO] No effect
                revision = revision,
                return_logits = False,  # Do not return logits
                fullgraph = True,  # No graph breaks
                use_exact_model_name = use_exact_model_name,
                offload_embedding = offload_embedding,
                float32_mixed_precision = float32_mixed_precision,
                # Pass vLLM/inference parameters
                fast_inference = fast_inference,
                gpu_memory_utilization = gpu_memory_utilization,
                float8_kv_cache = float8_kv_cache,
                random_state = random_state,
                max_lora_rank = max_lora_rank,
                disable_log_stats = disable_log_stats,
                qat_scheme = qat_scheme,
                load_in_fp8 = load_in_fp8,
                unsloth_tiled_mlp = unsloth_tiled_mlp,
                *args,
                **kwargs,
            )

        if use_gradient_checkpointing == "unsloth":
            patch_unsloth_smart_gradient_checkpointing(dtype = dtype)

        # Check if this is a local model, since the tokenizer gets overwritten
        if (
            os.path.exists(os.path.join(old_model_name, "tokenizer_config.json"))
            and os.path.exists(os.path.join(old_model_name, "tokenizer.json"))
            and os.path.exists(os.path.join(old_model_name, "special_tokens_map.json"))
        ):
            tokenizer_name = old_model_name
        else:
            tokenizer_name = kwargs.pop("tokenizer_name", None)

        if fast_inference:
            fast_inference, model_name = fast_inference_setup(model_name, model_config)

        load_in_4bit_kwargs = load_in_4bit
        load_in_8bit_kwargs = load_in_8bit
        if quantization_config is not None and not fast_inference:
            load_in_4bit_kwargs = False
            load_in_8bit_kwargs = False

        model, tokenizer = dispatch_model.from_pretrained(
            model_name = model_name,
            max_seq_length = max_seq_length,
            dtype = _get_dtype(dtype),
            load_in_4bit = load_in_4bit_kwargs,
            token = token,
            device_map = device_map,
            rope_scaling = rope_scaling,
            fix_tokenizer = fix_tokenizer,
            model_patcher = dispatch_model,
            tokenizer_name = tokenizer_name,
            trust_remote_code = trust_remote_code,
            revision = revision if not is_peft else None,
            fast_inference = fast_inference,
            gpu_memory_utilization = gpu_memory_utilization,
            float8_kv_cache = float8_kv_cache,
            random_state = random_state,
            max_lora_rank = max_lora_rank,
            disable_log_stats = disable_log_stats,
            *args,
            **kwargs,
        )

        if resize_model_vocab is not None:
            model.resize_token_embeddings(resize_model_vocab)

        # In case the model supports tagging, add the unsloth tag.
        if hasattr(model, "add_model_tags"):
            model.add_model_tags(
                [
                    "unsloth",
                ]
            )
        if hasattr(tokenizer, "add_model_tags"):
            tokenizer.add_model_tags(
                [
                    "unsloth",
                ]
            )

        if load_in_4bit:
            # Fix up bitsandbytes config, but respect user-provided quantization_config
            if quantization_config is None:
                compute_dtype = dtype_from_config(model.config)
                quantization_config = {
                    # Sometimes compute_dtype is not a string!!
                    "bnb_4bit_compute_dtype": compute_dtype,
                    "bnb_4bit_quant_type": "nf4",
                    "bnb_4bit_use_double_quant": True,
                    "llm_int8_enable_fp32_cpu_offload": False,
                    "llm_int8_has_fp16_weight": False,
                    "llm_int8_skip_modules": None,
                    "llm_int8_threshold": 6.0,
                    "load_in_4bit": True,
                    "load_in_8bit": False,
                    "quant_method": "bitsandbytes",
                }
                model.config.update({"quantization_config": quantization_config})
            else:
                if hasattr(quantization_config, "to_dict"):
                    model.config.update(
                        {"quantization_config": quantization_config.to_dict()}
                    )
                elif isinstance(quantization_config, dict):
                    model.config.update({"quantization_config": quantization_config})

        if load_in_fp8 != False:
            _tag_model_with_fp8_torchao_config(model, fp8_mode)

        if is_peft:
            # From https://github.com/huggingface/peft/issues/184
            # Now add PEFT adapters
            model = PeftModel.from_pretrained(
                model,
                old_model_name,
                token = token,
                revision = revision,
                is_trainable = True,
                trust_remote_code = trust_remote_code,
            )
            # Patch it as well!
            model = dispatch_model.patch_peft_model(model, use_gradient_checkpointing)

        # Patch Tiled MLP
        # To turn on, set UNSLOTH_TILED_MLP to "arctic", "target", or "target:{GB}"
        patch_tiled_mlp_choice = os.environ.get(
            "UNSLOTH_TILED_MLP", "arctic" if unsloth_tiled_mlp else "0"
        )
        if patch_tiled_mlp_choice != "0" or unsloth_tiled_mlp:
            patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)

        return model, tokenizer


from ..kernels import (
    patch_loss_functions,
    post_patch_loss_function,
)
from .vision import FastBaseModel
from transformers import (
    AutoModelForCausalLM,
)

try:
    from transformers import AutoModelForImageTextToText

    AutoModelForVision2Seq = AutoModelForImageTextToText
except ImportError:
    from transformers import AutoModelForVision2Seq


class FastModel(FastBaseModel):
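    """General loader built on FastBaseModel. Handles language, vision and other
    architectures, applying Unsloth's compile and dtype patches."""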
    @staticmethod
    def _prepare_for_qat(model, qat_scheme):
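        """Prepare the model for quantization-aware training (QAT) with the given scheme."""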
        model = _prepare_model_for_qat(model, qat_scheme)
        return model

    @staticmethod
    def from_pretrained(
        model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
        max_seq_length = 2048,
        dtype = None,
        load_in_4bit = True,  # 4bit QLoRA
        load_in_8bit = False,  # 8bit  LoRA
        load_in_16bit = False,  # 16bit LoRA
        full_finetuning = False,
        token = None,
        device_map = "sequential",
        rope_scaling = None,  # [TODO] No effect
        fix_tokenizer = True,  # [TODO] No effect
        trust_remote_code = False,
        use_gradient_checkpointing = "unsloth",
        resize_model_vocab = None,  # [TODO] No effect
        revision = None,
        return_logits = False,  # Whether to return logits
        fullgraph = True,  # No graph breaks
        use_exact_model_name = False,
        auto_model = None,
        whisper_language = None,
        whisper_task = None,
        unsloth_force_compile = False,
        offload_embedding = False,
        float32_mixed_precision = None,  # Forces float32 mixed precision
        # vLLM/inference parameters
        fast_inference = False,  # uses vLLM
        gpu_memory_utilization = 0.5,
        float8_kv_cache = False,
        random_state = 3407,
        max_lora_rank = 64,
        disable_log_stats = True,
        qat_scheme = None,
        load_in_fp8 = False,  # fp8 LoRA (True, False, 'block')
        unsloth_tiled_mlp = False,
        *args,
        **kwargs,
    ):
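        """
        Load any supported model and its tokenizer via `FastBaseModel`, applying
        Unsloth's compile, dtype and gradient checkpointing patches.
        Returns `(model, tokenizer)`.

        Example (minimal sketch using the defaults above):
            model, tokenizer = FastModel.from_pretrained(
                model_name = "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit",
                max_seq_length = 2048,
                load_in_4bit = True,
            )
        """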
        # Respect user-provided quantization_config (e.g. BitsAndBytesConfig)
        quantization_config = kwargs.get("quantization_config", None)
        if quantization_config is not None:
            if isinstance(quantization_config, dict):
                q_load_in_4bit = quantization_config.get("load_in_4bit", False)
                q_load_in_8bit = quantization_config.get("load_in_8bit", False)
            else:
                q_load_in_4bit = getattr(quantization_config, "load_in_4bit", False)
                q_load_in_8bit = getattr(quantization_config, "load_in_8bit", False)
            if q_load_in_4bit:
                load_in_4bit = True
                load_in_8bit = False
            if q_load_in_8bit:
                load_in_8bit = True
                load_in_4bit = False

        # Login to allow private models
        token = hf_login(token)
        if whisper_language is not None:
            assert type(whisper_language) is str
        if whisper_task is not None:
            assert type(whisper_task) is str
        # Align dtype with bnb_4bit_compute_dtype if provided and dtype is unset.
        if dtype is None and quantization_config is not None:
            bnb_compute_dtype = None
            if isinstance(quantization_config, dict):
                if quantization_config.get("load_in_4bit", False):
                    bnb_compute_dtype = quantization_config.get(
                        "bnb_4bit_compute_dtype", None
                    )
            else:
                if getattr(quantization_config, "load_in_4bit", False):
                    bnb_compute_dtype = getattr(
                        quantization_config, "bnb_4bit_compute_dtype", None
                    )
            if isinstance(bnb_compute_dtype, str):
                bnb_compute_dtype = getattr(torch, bnb_compute_dtype, None)
            if isinstance(bnb_compute_dtype, torch.dtype):
                dtype = bnb_compute_dtype
        SUPPORTS_BFLOAT16 = is_bfloat16_supported()
        if dtype is None:
            dtype = torch.float16 if not SUPPORTS_BFLOAT16 else torch.bfloat16
        elif dtype == torch.bfloat16 and not SUPPORTS_BFLOAT16:
            logger.warning_once(
                "Device does not support bfloat16. Will change to float16."
            )
            dtype = torch.float16
        assert dtype in (torch.float16, torch.bfloat16, torch.float32)
        assert load_in_fp8 in (True, False, "block")

        patch_compiled_autograd()
        patch_compiling_bitsandbytes()

        if full_finetuning and (load_in_4bit or load_in_8bit):
            print(
                "Unsloth: You selected full finetuning support, but 4bit / 8bit is enabled - disabling LoRA / QLoRA."
            )
            load_in_4bit = False
            load_in_8bit = False
            load_in_fp8 = False
            load_in_16bit = False

        if (
            int(load_in_4bit)
            + int(load_in_8bit)
            + int(load_in_16bit)
            + int(load_in_fp8 != False)
            >= 2
        ):
            raise RuntimeError(
                "Unsloth: Can only load in 4bit or 8bit or 16bit, not a combination!\n"
                "Also, we by default set `load_in_4bit = True`.\n"
                "If you want 8bit finetuning, set both `load_in_4bit = False` and `load_in_8bit = True`\n"
                "If you want 16bit LoRA finetuning, set `load_in_16bit = True`"
            )

        if qat_scheme is not None and not full_finetuning:
            raise ValueError(
                "Specifying `qat_scheme` in `FastLanguageModel.from_pretrained(...)` is only "
                "compatible with `full_finetuning=True`. If you wish to use QAT with LoRA, "
                "please pass in `qat_scheme` in `FastLanguageModel.get_peft_model(...)` instead."
            )
        if qat_scheme == "phone-deployment":
            qat_scheme = "int8-int4"
        # Check if 4bit is allowed specifically for AMD
        if not ALLOW_BITSANDBYTES and not use_exact_model_name:
            if load_in_4bit or load_in_8bit or model_name.lower().endswith("-bnb-4bit"):
                print(
                    "Unsloth: AMD currently is not stable with 4bit bitsandbytes. Disabling for now."
                )
            load_in_4bit = False

        if fast_inference:
            if importlib.util.find_spec("vllm") is None:
                raise ImportError(
                    "Unsloth: Please install vLLM before enabling `fast_inference`!\n"
                    "You can do this in a terminal via `pip install vllm`"
                )
            if DEVICE_TYPE_TORCH == "cuda":
                for i in range(DEVICE_COUNT):
                    # [TODO] DGX Spark vLLM breaks
                    if "NVIDIA GB10" in str(torch.cuda.get_device_name(i)).upper():
                        print(
                            "Unsloth: DGX Spark detected - `fast_inference=True` is currently broken as of January 2026.\n"
                            "Defaulting to native Unsloth inference."
                        )
                        fast_inference = False
                        break

        # [TODO] For now load_in_fp8 only works with fast_inference, i.e. vLLM
        if load_in_fp8 != False:
            if not fast_inference:
                raise NotImplementedError(
                    "Unsloth: set `fast_inference = True` when doing `load_in_fp8`."
                )

        # Find FP8, BnB 4bit, other mapped names
        old_model_name = model_name
        fp8_mode = None
        if not use_exact_model_name:
            new_model_name = get_model_name(
                model_name, load_in_4bit = load_in_4bit, load_in_fp8 = load_in_fp8
            )
            if new_model_name is None and load_in_fp8 != False:
                fp8_mode = _get_fp8_mode_and_check_settings(
                    load_in_fp8,
                    fast_inference,
                    full_finetuning,
                    load_in_4bit,
                    load_in_8bit,
                    load_in_16bit,
                    use_exact_model_name,
                )
                model_name = _offline_quantize_to_fp8(model_name, fp8_mode)
            else:
                assert new_model_name is not None
                model_name = new_model_name

        # Check if pre-quantized models are allowed
        # E.g. AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64
        if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
            ("-unsloth-bnb-4bit", "-bnb-4bit")
        ):
            # Strip the quantization suffix while preserving the original casing
            if model_name.lower().endswith("-unsloth-bnb-4bit"):
                model_name = model_name[: -len("-unsloth-bnb-4bit")]
            else:
                model_name = model_name[: -len("-bnb-4bit")]
        # For -BF16 model names, disable 4bit / 8bit / fp8 loading and use 16bit
        if model_name.lower().endswith("-bf16"):
            load_in_4bit = False
            load_in_8bit = False
            load_in_fp8 = False
            load_in_16bit = True

        # Check modelscope
        if USE_MODELSCOPE and not os.path.exists(model_name):
            from modelscope import snapshot_download

            model_name = snapshot_download(model_name)

        # First check if it's a normal model via AutoConfig
        from huggingface_hub.utils import (
            disable_progress_bars,
            enable_progress_bars,
            are_progress_bars_disabled,
        )

        was_disabled = are_progress_bars_disabled()
        disable_progress_bars()

        autoconfig_error = None
        peft_error = None
        model_config = None
        peft_config = None
        try:
            model_config = AutoConfig.from_pretrained(
                model_name,
                token = token,
                revision = revision,
                trust_remote_code = trust_remote_code,
            )
            is_model = True
        except ImportError:
            raise
        except Exception as error:
            autoconfig_error = str(error)
            if "architecture" in autoconfig_error:
                raise ValueError(
                    f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"
                    f"Please update transformers via `pip install --upgrade transformers` and try again."
                )
            is_model = False
        try:
            peft_config = PeftConfig.from_pretrained(
                model_name,
                token = token,
                revision = revision,
                trust_remote_code = trust_remote_code,
            )
            is_peft = True
        except ImportError:
            raise
        except Exception as error:
            peft_error = str(error)
            if "architecture" in peft_error:
                raise ValueError(
                    f"`{model_name}` is not supported yet in `transformers=={transformers_version}`.\n"
                    f"Please update transformers via `pip install --upgrade transformers` and try again."
                )
            is_peft = False
        # Old transformers versions check
        both_exist = (is_model and is_peft) and not SUPPORTS_LLAMA32
        # Error out if both LoRA and normal model config exists.
        if both_exist:
            raise RuntimeError(
                "Unsloth: Your repo has a LoRA adapter and a base model.\n"
                "You have 2 files `config.json` and `adapter_config.json`.\n"
                "We must only allow one config file.\n"
                "Please separate the LoRA and base models to 2 repos."
            )
        model_types = get_transformers_model_type(
            peft_config if peft_config is not None else model_config,
            trust_remote_code = trust_remote_code,
        )
        model_types_all = ",".join(model_types) + ","

        # Save model types and loading method
        lowered_model_name = model_name.lower()
        string = os.environ.get("UNSLOTH_MODEL_NAME", "") + model_types_all
        if load_in_4bit:
            string += "_load_in_4bit_"
        if load_in_8bit:
            string += "_load_in_8bit_"
        if load_in_16bit:
            string += "_load_in_16bit_"
        if load_in_fp8:
            string += "load_in_fp8"
        os.environ["UNSLOTH_MODEL_NAME"] = string

        # Check versions
        LATEST = "\nPlease use transformers via `pip install --no-deps git+https://github.com/huggingface/transformers.git`"
        NIGHTLY = '\nPlease use nightly transformers via `pip install --upgrade "transformers>=4.49.0"`'
        # Pixtral
        if "pixtral" in model_types_all and transformers_version < Version("4.49.0"):
            raise RuntimeError(
                "Unsloth: Pixtral only works on transformers >= 4.49.0." + LATEST
            )
        # Qwen 2.5
        elif "qwen2_5" in model_types_all and transformers_version < Version("4.49.0"):
            raise RuntimeError(
                "Unsloth: Qwen 2.5 only works on transformers >= 4.49.0." + LATEST
            )
        # Gemma 3N must be before Gemma 3
        elif "gemma3n" in model_types_all:
            if transformers_version < Version("4.53.0"):
                raise RuntimeError(
                    "Unsloth: Gemma 3N only works on transformers >= 4.53.0" + LATEST
                )
            os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
            os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
                "float16;torch.float16;torch.float16;"
                "if name.endswith('norm'): "
                "module._pre_set_compute_dtype = torch.float32\n"
                ";"
                "from unsloth_zoo.temporary_patches.gemma3n import patch_Gemma3nConv_Embed_forwards; patch_Gemma3nConv_Embed_forwards()"
            )
            # Set norms to float32 since they get upcast to float32 anyway
            # (common to both gemma-3 and gemma-3n)
            os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
        # Gemma 3
        elif "gemma3" in model_types_all:
            if transformers_version < Version("4.50.0.dev0"):
                raise RuntimeError(
                    "Unsloth: Gemma 3 only works on transformers >= 4.50.0." + NIGHTLY
                )
            # Set norms to float32 since they get upcast to float32 anyway
            # (common to both gemma-3 and gemma-3n)
            os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
        # Cohere
        elif "cohere2" in model_types_all and transformers_version < Version(
            "4.50.0.dev0"
        ):
            raise RuntimeError(
                "Unsloth: Cohere's Command model only works on transformers >= 4.50.0."
                + NIGHTLY
            )
        # Sesame
        elif "csm" in model_types_all:
            os.environ["UNSLOTH_COMPILE_DISABLE"] = "partial"  # Inference is too slow
            os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"  # Sesame fails
            os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
                "all;torch.float32;torch.float16;"
                "if name.endswith(('_proj', 'fc1', 'fc2', 'codebook', 'head')): module.to(torch.float16)"
                ";"
            )
        # Granite 4
        elif "granitemoehybrid" in model_types_all:
            # Granite-4 rms norms are stored as 16 bit, but we upcast
            os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
            os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
        # Olmo 2
        elif "olmo2" in model_types_all and transformers_version < Version(
            "4.50.0.dev0"
        ):
            raise RuntimeError(
                "Unsloth: OLMo-2 only works on transformers >= 4.50.0." + NIGHTLY
            )
        elif "falcon_h1" in model_types_all:
            # Falcon must use float32 Triton, i.e. TRITON_F32_DEFAULT = 'ieee',
            # since the Mamba kernels error out when using lower precision
            os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
                "float16;torch.float32;torch.float16;"
                "if name.endswith(('q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj', 'head')): module.to(torch.float16)"
                ";"
                "os.environ['TRITON_F32_DEFAULT'] = 'ieee'"
            )
        elif "gpt_oss" in model_types_all:
            os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
            if not load_in_4bit:
                # Only upcast MoE biases for MXFP4, not BnB
                # Set norms to float32 since they get upcast to float32 anyway
                os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
                    "all;None;None;"
                    "x = 'gate_up_proj_bias'\n"
                    "if hasattr(module, x): "
                    "setattr(module, x, torch.nn.Parameter(getattr(module, x).to(torch.float32)) if isinstance(getattr(module, x), torch.nn.Parameter) else getattr(module, x).to(torch.float32))\n"
                    ""
                    "x = 'down_proj_bias'\n"
                    "if hasattr(module, x): "
                    "setattr(module, x, torch.nn.Parameter(getattr(module, x).to(torch.float32)) if isinstance(getattr(module, x), torch.nn.Parameter) else getattr(module, x).to(torch.float32))\n"
                    ""
                    ";"
                )
            else:
                # Set down projection compute dtype to be float32 for float16 machines
                # Set norms to float32 since they get upcast to float32 anyway
                os.environ["UNSLOTH_FORCE_CUSTOM_DTYPE"] = (
                    "torch.float16;torch.bfloat16;torch.float16;"
                    "if ('down_projs' in name) and hasattr(module, 'weight') and "
                    "torch.amax(dequantize_module_weight(module)) >= 0:"
                    "module._pre_set_compute_dtype = torch.float32\n"
                    ""
                    "if ('mlp.router' in name) and hasattr(module, 'weight'):"
                    "module._pre_set_compute_dtype = torch.float32\n"
                    ";"
                )
            # Set norms to float32 since they get upcast to float32 anyway
            os.environ["UNSLOTH_HIGH_PRECISION_LAYERNORM"] = "1"
        else:
            for check_model_name in DISABLE_COMPILE_MODEL_NAMES:
                if check_model_name in lowered_model_name:
                    os.environ["UNSLOTH_COMPILE_DISABLE"] = "partial"
                    os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"
                    if transformers_version < Version("4.50.0.dev0"):
                        raise RuntimeError(
                            f"Unsloth: {check_model_name} only works on transformers >= 4.50.0."
                            + NIGHTLY
                        )
                    break

        if auto_model is not None:
            # All other models need to disable static cache
            os.environ["UNSLOTH_DISABLE_STATIC_GENERATION"] = "1"

        # Newer transformers versions need a manual check.
        if SUPPORTS_LLAMA32:
            # Check if folder exists locally
            if os.path.isdir(model_name):
                exist_adapter_config = os.path.exists(
                    os.path.join(model_name, "adapter_config.json")
                )
                exist_config = os.path.exists(os.path.join(model_name, "config.json"))
                both_exist = exist_adapter_config and exist_config
            else:
                files = HfFileSystem(token = token).glob(f"{model_name}/*.json")
                files = list(os.path.split(x)[-1] for x in files)
                if (
                    sum(x == "adapter_config.json" or x == "config.json" for x in files)
                    >= 2
                ):
                    both_exist = True

        if not is_model and not is_peft:
            error = autoconfig_error if autoconfig_error is not None else peft_error
            # Old transformers version
            if "rope_scaling" in error.lower() and not SUPPORTS_LLAMA31:
                raise ImportError(
                    f"Unsloth: Your transformers version of {transformers_version} does not support new RoPE scaling methods.\n"
                    f"This includes Llama 3.1. The minimum required version is 4.43.2\n"
                    f'Try `pip install --upgrade "transformers>=4.43.2"`\n'
                    f"to obtain the latest transformers build, then restart this session."
                )
            # Create a combined error message showing both failures
            combined_error = (
                "Unsloth: Failed to load model. Both AutoConfig and PeftConfig loading failed.\n\n"
                f"AutoConfig error: {autoconfig_error}\n\n"
                f"PeftConfig error: {peft_error}\n\n"
            )
            raise RuntimeError(combined_error)

        # Get base model for PEFT:
        if is_peft:
            # Check base model again for PEFT
            model_name = peft_config.base_model_name_or_path
            if not use_exact_model_name:
                model_name = get_model_name(model_name, load_in_4bit)
            # Check if pre-quantized models are allowed
            # E.g. AMD Instinct GPUs need blocksize = 128, but our pre-quants are blocksize = 64
            if not ALLOW_PREQUANTIZED_MODELS and model_name.lower().endswith(
                ("-unsloth-bnb-4bit", "-bnb-4bit")
            ):
                # Strip the quantization suffix while preserving the original casing
                if model_name.lower().endswith("-unsloth-bnb-4bit"):
                    model_name = model_name[: -len("-unsloth-bnb-4bit")]
                else:
                    model_name = model_name[: -len("-bnb-4bit")]
            # For -BF16 model names, disable 4bit / 8bit / fp8 loading and use 16bit
            if model_name.lower().endswith("-bf16"):
                load_in_4bit = False
                load_in_8bit = False
                load_in_fp8 = False
                load_in_16bit = True

            model_config = AutoConfig.from_pretrained(
                model_name,
                token = token,
                trust_remote_code = trust_remote_code,
            )

        if not was_disabled:
            enable_progress_bars()

        do_logging = os.environ.get("UNSLOTH_ENABLE_LOGGING", "0") == "1"
        if do_logging:
            redirector = contextlib.nullcontext()
        else:
            redirector = contextlib.redirect_stdout(open(os.devnull, "w"))

        model_types = ["siglip"] + model_types
        # Set forced float32 env flag
        os.environ["UNSLOTH_FORCE_FLOAT32"] = "0"
        do_forced_float32 = False
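        # Pick the first model type that is not the prepended "siglip" entry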
        for model_type_arch in model_types:
            if model_type_arch != "siglip":
                break
        global FORCE_FLOAT32
        for disable_name in FORCE_FLOAT32:
            # Entries carry a trailing comma so matching against the comma-terminated model_types_all is exact
            if (
                disable_name.lower()
                == model_type_arch.lower().replace("-", "").replace("_", "")
                or disable_name.lower() in model_types_all
            ) and ((dtype == torch.float16) or not SUPPORTS_BFLOAT16):
                os.environ["UNSLOTH_FORCE_FLOAT32"] = "1"
                dtype = torch.bfloat16  # Change to bfloat16 loading
                break
        # Patch gradient checkpointing
        if use_gradient_checkpointing == "unsloth":
            patch_unsloth_smart_gradient_checkpointing(dtype = dtype)
        with redirector:
            patch_loss_functions(torch_compile = False)
            model_types, supports_sdpa = unsloth_compile_transformers(
                dtype = dtype,
                model_name = model_name,
                model_types = model_types,
                token = token,
                sdpa_dynamic_mask = True,
                sdpa_bool_masks = True,
                sdpa_gqa_replace = True,
                sdpa_dynamic_compile = True,
                compile_attention = True,
                disable_causal_masks = True,
                compile_torch_modules = True,
                compile_custom_modules = True,
                compile_function_calls = True,
                fuse_lm_head = True,
                gradient_checkpointing = True,
                manual_replacements = True,
                fast_lora_forwards = True,
                fast_residual_stream = False,
                accurate_accumulation = True,
                epilogue_fusion = True,
                max_autotune = False,
                shape_padding = True,
                cudagraphs = False,
                debug = False,
                fullgraph = fullgraph,
                import_from_cache = False,
                disable = False,
                return_logits = return_logits,
                trust_remote_code = trust_remote_code,
                unsloth_force_compile = unsloth_force_compile,
            )
        # Fix SDPA issues
        for model_type in DISABLE_SDPA_MODEL_NAMES:
            if model_type in model_types_all:
                supports_sdpa = False

        # Check if this is a local model, since the tokenizer gets overwritten
        if (
            os.path.exists(os.path.join(old_model_name, "tokenizer_config.json"))
            and os.path.exists(os.path.join(old_model_name, "tokenizer.json"))
            and os.path.exists(os.path.join(old_model_name, "special_tokens_map.json"))
        ):
            tokenizer_name = old_model_name
        else:
            tokenizer_name = kwargs.pop("tokenizer_name", None)

        # Check if VLM
        architectures = getattr(model_config, "architectures", None)
        if architectures is None:
            architectures = []
        is_vlm = any(x.endswith("ForConditionalGeneration") for x in architectures)
        is_vlm = is_vlm or hasattr(model_config, "vision_config")
        if auto_model is None:
            auto_model = AutoModelForVision2Seq if is_vlm else AutoModelForCausalLM

        load_in_4bit_kwargs = load_in_4bit
        load_in_8bit_kwargs = load_in_8bit
        if quantization_config is not None and not fast_inference:
            load_in_4bit_kwargs = False
            load_in_8bit_kwargs = False

        model, tokenizer = FastBaseModel.from_pretrained(
            model_name = model_name,
            max_seq_length = max_seq_length,
            dtype = _get_dtype(dtype),
            load_in_4bit = load_in_4bit_kwargs,
            load_in_8bit = load_in_8bit_kwargs,
            load_in_16bit = load_in_16bit,
            full_finetuning = full_finetuning,
            token = token,
            device_map = device_map,
            trust_remote_code = trust_remote_code,
            revision = revision if not is_peft else None,
            model_types = model_types,
            tokenizer_name = tokenizer_name,
            auto_model = auto_model,
            use_gradient_checkpointing = use_gradient_checkpointing,
            supports_sdpa = supports_sdpa,
            whisper_language = whisper_language,
            whisper_task = whisper_task,
            auto_config = model_config,
            offload_embedding = offload_embedding,
            float32_mixed_precision = float32_mixed_precision,
            # Pass vLLM/inference parameters
            fast_inference = fast_inference,
            gpu_memory_utilization = gpu_memory_utilization,
            float8_kv_cache = float8_kv_cache,
            random_state = random_state,
            max_lora_rank = max_lora_rank,
            disable_log_stats = disable_log_stats,
            *args,
            **kwargs,
        )

        if resize_model_vocab is not None:
            model.resize_token_embeddings(resize_model_vocab)

        # In case the model supports tagging, add the unsloth tag.
        if hasattr(model, "add_model_tags"):
            model.add_model_tags(
                [
                    "unsloth",
                ]
            )
        if hasattr(tokenizer, "add_model_tags"):
            tokenizer.add_model_tags(
                [
                    "unsloth",
                ]
            )

        if load_in_4bit:
            # Fix up bitsandbytes config, but respect user-provided quantization_config
            if quantization_config is None:
                compute_dtype = dtype_from_config(model.config)
                quantization_config = {
                    # Sometimes compute_dtype is not a string!!
                    "bnb_4bit_compute_dtype": compute_dtype,
                    "bnb_4bit_quant_type": "nf4",
                    "bnb_4bit_use_double_quant": True,
                    "llm_int8_enable_fp32_cpu_offload": False,
                    "llm_int8_has_fp16_weight": False,
                    "llm_int8_skip_modules": None,
                    "llm_int8_threshold": 6.0,
                    "load_in_4bit": True,
                    "load_in_8bit": False,
                    "quant_method": "bitsandbytes",
                }
                model.config.update({"quantization_config": quantization_config})
            else:
                if hasattr(quantization_config, "to_dict"):
                    model.config.update(
                        {"quantization_config": quantization_config.to_dict()}
                    )
                elif isinstance(quantization_config, dict):
                    model.config.update({"quantization_config": quantization_config})

        if load_in_fp8 != False:
            _tag_model_with_fp8_torchao_config(model, fp8_mode)

        if is_peft:
            # From https://github.com/huggingface/peft/issues/184
            # Now add PEFT adapters
            model = PeftModel.from_pretrained(
                model,
                old_model_name,
                token = token,
                revision = revision,
                is_trainable = True,
                trust_remote_code = trust_remote_code,
            )
            # Patch it as well!
            model = FastBaseModel.post_patch_model(
                model, use_gradient_checkpointing, trust_remote_code = trust_remote_code
            )

        # Apply QAT if specified
        if qat_scheme is not None:
            print("Unsloth: Applying QAT to mitigate quantization degradation")
            model = FastModel._prepare_for_qat(model, qat_scheme)

        # Patch Tiled MLP
        # To turn on, set UNSLOTH_TILED_MLP to "arctic", "target", or "target:{GB}"
        patch_tiled_mlp_choice = os.environ.get(
            "UNSLOTH_TILED_MLP", "arctic" if unsloth_tiled_mlp else "0"
        )
        if patch_tiled_mlp_choice != "0" or unsloth_tiled_mlp:
            patch_tiled_mlp(model, patch_options_str = patch_tiled_mlp_choice)

        return model, tokenizer


class FastVisionModel(FastModel):
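    """Alias of FastModel kept as a vision-oriented entry point."""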
    pass


class FastTextModel(FastModel):
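    """Alias of FastModel kept as a text-oriented entry point."""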
    pass
