# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

from datasets import load_dataset
from huggingface_hub import ModelCard
from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    r"""
    Command-line options of the Math-Shepherd conversion script.

    Args:
        push_to_hub (`bool`, *optional*, defaults to `False`):
            If `True`, the processed dataset is uploaded to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-lib/math_shepherd"`):
            Identifier of the Hub repository that receives the dataset.
        dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
            Number of worker processes used while processing the dataset.
    """

    # Field order is part of the generated `__init__` signature: keep it stable.
    push_to_hub: bool = field(
        metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
        default=False,
    )
    repo_id: str = field(
        metadata={"help": "Hugging Face repository ID to push the dataset to."},
        default="trl-lib/math_shepherd",
    )
    dataset_num_proc: Optional[int] = field(
        metadata={"help": "Number of workers to use for dataset processing."},
        default=None,
    )


def process_example(example):
    r"""
    Convert one raw Math-Shepherd row into a stepwise-supervision example.

    The raw `"input"` contains the problem followed by the solution steps, where the end of each step is annotated
    with the marker `"ки"`. The raw `"label"` is read at the marker positions, where it must contain `"+"` (step is
    correct) or `"-"` (step is incorrect).

    Args:
        example (`dict`):
            Row with the string fields `"input"` and `"label"`.

    Returns:
        `dict`: A dictionary with the keys `"prompt"` (`str`, the problem statement), `"completions"` (`list[str]`,
        the solution steps) and `"labels"` (`list[bool]`, one flag per step, `True` for `"+"`).

    Raises:
        `ValueError`: If the input has no step marker, if the label does not hold `"+"` or `"-"` at a marker
        position, or if the first step cannot be separated from the problem statement.
    """
    # "ки" is two characters long while its counterpart in the label ("+" or "-") is a single one. Collapsing it into
    # the single character "ⶻ" makes every marker position in the input line up with the same position in the label.
    inputs = example["input"].replace("ки", "ⶻ")
    label = example["label"]

    # Positions of the markers, which are also the positions of the "+"/"-" annotations in the label
    indexes = [m.start() for m in re.finditer("ⶻ", inputs)]
    if not indexes:
        raise ValueError(f"No step marker found in input: {inputs}")

    # Validate explicitly instead of using `assert`: assertions are stripped when Python runs with `-O`, which would
    # let misaligned annotations silently produce wrong labels.
    for idx in indexes:
        if idx >= len(label) or label[idx] not in ("+", "-"):
            raise ValueError(f"Expected '+' or '-' at position {idx} of the label: {label}")
    labels = [label[idx] == "+" for idx in indexes]

    # Cut the input at every marker. Whatever follows the last marker carries no label and is dropped.
    steps = [inputs[start:end] for start, end in zip([0, *indexes], indexes)]

    # The first segment holds the problem statement together with the first step; each following segment starts with
    # the marker that closed the previous step, so the leading character is removed along with surrounding whitespace.
    prompt = steps[0]
    completions = [step[1:].strip() for step in steps[1:]]

    # At this point, we need to retrieve the first step from the prompt.
    # First, we handle particular cases (annotation error) where we have a first label before the end of the prompt.
    if prompt.startswith(
        (
            "Mr. Rocky",
            "Parker",
            "What is the smallest positive",
            " The Myth",
            "Let $\\mathbf{a}$",
            "Find the arithmetic",
            "Determine an ordered pair",
            "Determine the ordered pair",
            "At the Quill and Scroll stationery",
            "Round to the nearest",
            r"Calculate $\sqrt{10p}",
            r"Simplify $\sqrt{28x}",
        )
    ):
        # Some spotted datasets errors where there is an annotation in the prompt: we remove it
        labels = labels[1:]
    else:
        # General case: the first step starts at "Step 1:" or "step 1:". `str.partition` cuts at the first occurrence,
        # so a marker that is repeated later in the step no longer breaks the tuple unpacking.
        for marker in ("Step 1:", "step 1:"):
            if marker in prompt:
                prompt, _, remainder = prompt.partition(marker)
                first_step = marker + remainder
                break
        else:
            # Less common: no explicit marker, the problem statement ends with its question mark. This heuristic is
            # only trusted when there is exactly one "?"; otherwise the boundary is ambiguous and we refuse to guess.
            question_marks = prompt.count("?")
            if question_marks == 0:
                raise ValueError(f"Prompt can't be processed: {prompt}")
            if question_marks > 1:
                raise ValueError(f"Prompt can't be processed (several '?' and no 'Step 1:' marker): {prompt}")
            question, _, first_step = prompt.partition("?")
            prompt = question + "?"
        completions = [first_step.strip()] + completions

    # Strip the prompt
    prompt = prompt.strip()

    # Sanity check that the length of the completions is the same as the length of the labels
    if len(completions) != len(labels):
        raise ValueError(f"Got {len(completions)} completions for {len(labels)} labels: {inputs}")

    return {"prompt": prompt, "completions": completions, "labels": labels}


# Dataset card pushed next to the processed dataset. Links use absolute URLs so that they resolve from the rendered
# card page on the Hub.
model_card = ModelCard("""
---
tags: [trl]
---

# Math-Shepherd Dataset

## Summary

The Math-Shepherd dataset is a processed version of [Math-Shepherd dataset](https://huggingface.co/datasets/peiyi9979/Math-Shepherd), designed to train models using the [TRL library](https://github.com/huggingface/trl) for stepwise supervision tasks. It provides step-by-step solutions to mathematical problems, enabling models to learn and verify each step of a solution, thereby enhancing their reasoning capabilities.

## Data Structure

- **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard)
- **Type**: [Stepwise supervision](https://huggingface.co/docs/trl/main/dataset_formats#stepwise-supervision)

Columns:
- `"prompt"`: The problem statement.
- `"completions"`: A list of reasoning steps generated to solve the problem.
- `"labels"`: A list of booleans or floats indicating the correctness of each corresponding reasoning step.

This structure allows models to learn the correctness of each step in a solution, facilitating improved reasoning and problem-solving abilities.

## Generation script

The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/math_shepherd.py).
""")

if __name__ == "__main__":
    # Read the command-line flags into a `ScriptArguments` instance
    script_args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]

    # Load the raw training split and convert every row to the stepwise-supervision format; the raw columns are
    # dropped so that only the fields returned by `process_example` remain.
    raw_dataset = load_dataset("peiyi9979/Math-Shepherd", split="train")
    processed_dataset = raw_dataset.map(
        process_example,
        num_proc=script_args.dataset_num_proc,
        remove_columns=["input", "label", "task"],
    )

    # Hold out 5% of the rows as a test split; the fixed seed keeps the split reproducible
    dataset_splits = processed_dataset.train_test_split(test_size=0.05, seed=42)

    # Uploading is opt-in: the dataset goes first, then its card
    if script_args.push_to_hub:
        target_repo = script_args.repo_id
        dataset_splits.push_to_hub(target_repo)
        model_card.push_to_hub(target_repo, repo_type="dataset")
