# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of TF-Keras training engine related to Python generators of array data.
"""

import functools
import math

import numpy as np
import tensorflow.compat.v2 as tf

from tf_keras.src import backend
from tf_keras.src import callbacks as cbks
from tf_keras.src.engine import training_utils
from tf_keras.src.engine import training_utils_v1
from tf_keras.src.utils import data_utils
from tf_keras.src.utils import generic_utils
from tf_keras.src.utils.mode_keys import ModeKeys

# isort: off
from tensorflow.python.platform import tf_logging as logging


def model_iteration(
    model,
    data,
    steps_per_epoch=None,
    epochs=1,
    verbose=1,
    callbacks=None,
    validation_data=None,
    validation_steps=None,
    validation_freq=1,
    class_weight=None,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
    shuffle=False,
    initial_epoch=0,
    mode=ModeKeys.TRAIN,
    batch_size=None,
    steps_name="steps",
    **kwargs,
):
    """Loop function for arrays of data with modes TRAIN/TEST/PREDICT.

    Args:
        model: TF-Keras Model instance.
        data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or `(x, y)` or
          `(x, y, sample_weights)`) or a generator or
          `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
        steps_per_epoch: Total number of steps (batches of samples) before
          declaring one epoch finished and starting the next epoch. Ignored with
          the default value of `None`.
        epochs: Number of times to iterate over the data.
        verbose: 0, 1, or 2. Verbosity mode.
          0 = silent, 1 = progress bar, 2 = one line per epoch.
          Note that the progress bar is not particularly useful when
          logged to a file, so verbose=2 is recommended when not running
          interactively (eg, in a production environment).
        callbacks: List of callbacks to be called during training.
        validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
          `(x, y)` or `(x, y, sample_weights)`) or a generator or
          `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
        validation_steps: Total number of steps (batches of samples) before
          declaring validation finished.
        validation_freq: Only relevant if validation data is provided. Integer
          or `collections.abc.Container` instance (e.g. list, tuple, etc.). If
          an integer, specifies how many training epochs to run before a new
          validation run is performed, e.g. `validation_freq=2` runs validation
          every 2 epochs. If a Container, specifies the epochs on which to run
          validation, e.g. `validation_freq=[1, 2, 10]` runs validation at the
          end of the 1st, 2nd, and 10th epochs.
        class_weight: Dictionary mapping class indices to a weight for the
            class.
        max_queue_size: Integer. Maximum size for the generator queue. If
          unspecified, `max_queue_size` will default to 10.
        workers: Integer. Maximum number of processes to spin up when using
          process-based threading. If unspecified, `workers` will default to 1.
          If 0, will execute the generator on the main thread.
        use_multiprocessing: Boolean. If `True`, use process-based threading. If
          unspecified, `use_multiprocessing` will default to `False`. Note that
          because this implementation relies on multiprocessing, you should not
          pass non-pickleable arguments to the generator as they can't be passed
          easily to children processes.
        shuffle: Boolean. Whether to shuffle the order of the batches at the
          beginning of each epoch. Only used with instances of `Sequence`
          (`keras.utils.Sequence`). Has no effect when `steps_per_epoch` is not
          `None`.
        initial_epoch: Epoch at which to start training (useful for resuming a
          previous training run).
        mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
        batch_size: Integer batch size or None if unknown. Will only be used if
          `data` is in NumPy/Tensor format.
        steps_name: The string name of the steps argument, either `steps`,
          `validation_steps`, or `steps_per_epoch`. Only used for error message
          formatting.
        **kwargs: Additional arguments for backwards compatibility. `steps` is
          accepted as an alias for `steps_per_epoch`.

    Returns:
        - In TRAIN mode: `History` object.
        - In TEST mode: Evaluation metrics.
        - In PREDICT mode: Outputs of the Model called on inputs.

    Raises:
        ValueError: in case of invalid arguments.

    Note:
        Once the generator enqueuer has been created, the rest of the loop runs
        under `try`/`finally`. The enqueuer is stopped, and any eager
        learning-phase scope entered here is exited, even when an exception
        propagates out of the loop.
    """
    if "steps" in kwargs:
        steps_per_epoch = kwargs["steps"]

    # Determine the number of steps per epoch and whether we should reset the
    # dataset at the end of each epoch.
    reset_dataset_after_each_epoch = False
    original_dataset = None
    is_dataset = isinstance(data, (tf.data.Dataset, tf.compat.v1.data.Dataset))
    if is_dataset:
        original_dataset = data
        if steps_per_epoch is None:
            reset_dataset_after_each_epoch = True
            steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
                model,
                data,
                steps_per_epoch,
                epochs=epochs,
                steps_name=steps_name,
            )

    # Convert to a format that supports `next(generator)`.
    generator, steps_per_epoch = convert_to_generator_like(
        data,
        steps_per_epoch=steps_per_epoch,
        batch_size=batch_size,
        epochs=epochs - initial_epoch,
        shuffle=shuffle,
    )

    do_validation = validation_data is not None
    is_sequence = isinstance(generator, data_utils.Sequence)
    _validate_arguments(
        is_sequence,
        is_dataset,
        use_multiprocessing,
        workers,
        steps_per_epoch,
        validation_data,
        validation_steps,
        mode,
        kwargs,
    )

    batch_function = _make_execution_function(
        model, mode, class_weight=class_weight
    )

    # Create the queue for the generator.
    enqueuer = None
    if not is_dataset:
        generator, enqueuer = _make_enqueued_generator(
            generator,
            workers=workers,
            use_multiprocessing=use_multiprocessing,
            max_queue_size=max_queue_size,
            shuffle=shuffle,
        )

    # Assigned only after `__enter__` succeeds, so the `finally` clause below
    # never exits a scope that was not entered.
    learning_phase_scope = None
    try:
        num_samples_or_steps, use_steps = _get_num_samples_or_steps(
            data, steps_per_epoch
        )

        count_mode = "steps" if use_steps else "samples"
        callbacks = cbks.configure_callbacks(
            callbacks,
            model,
            do_validation=do_validation,
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            batch_size=batch_size,
            samples=num_samples_or_steps,
            count_mode=count_mode,
            verbose=verbose,
            mode=mode,
        )

        if mode == ModeKeys.PREDICT:
            aggregator = training_utils_v1.OutputsAggregator(
                True, steps=steps_per_epoch
            )
        else:
            aggregator = training_utils_v1.MetricsAggregator(
                True, steps=steps_per_epoch
            )

        should_set_learning_phase = tf.executing_eagerly() and model.run_eagerly
        if should_set_learning_phase:
            scope = backend.eager_learning_phase_scope(
                1 if mode == ModeKeys.TRAIN else 0
            )
            scope.__enter__()
            learning_phase_scope = scope

        callbacks.model.stop_training = False
        callbacks._call_begin_hook(mode)

        initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
            initial_epoch, mode
        )

        for epoch in range(initial_epoch, epochs):
            if callbacks.model.stop_training:
                break

            # Setup work for each epoch.
            model.reset_metrics()
            epoch_logs = {}
            if mode == ModeKeys.TRAIN:
                callbacks.on_epoch_begin(epoch, epoch_logs)

            if steps_per_epoch is None:
                # Loop over dataset until `OutOfRangeError` is raised.
                target_steps = np.inf
            else:
                # Loop over dataset for the specified number of steps.
                target_steps = steps_per_epoch

            step = 0
            while step < target_steps:
                batch_data = _get_next_batch(generator)
                if batch_data is None:
                    if is_dataset:
                        # The dataset passed by the user ran out of batches.
                        # Now we know the cardinality of the dataset. If
                        # steps_per_epoch was specified, then running out of
                        # data is unexpected, so we stop training and inform
                        # the user.
                        if steps_per_epoch:
                            callbacks.model.stop_training = True
                            logging.warning(
                                "Your dataset ran out of data; interrupting "
                                "training. Make sure that your dataset can "
                                "generate at least `%s * epochs` batches (in "
                                "this case, %d batches). You may need to use "
                                "the repeat() function when building your "
                                "dataset."
                                % (steps_name, steps_per_epoch * epochs)
                            )
                        elif step > 0:
                            steps_per_epoch = step
                            aggregator.steps = steps_per_epoch
                    else:
                        # We ran out of batches while the user passed an
                        # iterator (legacy).
                        callbacks.model.stop_training = True
                        logging.warning(
                            "Your dataset iterator ran out of data; "
                            "interrupting training. Make sure that your "
                            "iterator can generate at least `%s * epochs` "
                            "batches (in this case, %d batches). You may "
                            "need to use the repeat() function when "
                            "building your dataset."
                            % (steps_name, steps_per_epoch * epochs)
                        )
                    break

                # `batch_size` used for validation data if validation
                # data is NumPy/EagerTensors.
                batch_size = int(tf.nest.flatten(batch_data)[0].shape[0])

                # Callbacks batch begin.
                batch_logs = {"batch": step, "size": batch_size}
                callbacks._call_batch_hook(mode, "begin", step, batch_logs)

                is_deferred = not model._is_compiled
                batch_outs = batch_function(*batch_data)
                if not isinstance(batch_outs, list):
                    batch_outs = [batch_outs]

                if step == 0:
                    aggregator.create(batch_outs)

                    if is_deferred:
                        # Set callbacks params. We do this here when model is
                        # compiled only in the first iteration of this loop
                        # (deferred build scenario).
                        cbks.set_callback_parameters(
                            callbacks,
                            model,
                            do_validation=do_validation,
                            batch_size=batch_size,
                            epochs=epochs,
                            steps_per_epoch=steps_per_epoch,
                            samples=num_samples_or_steps,
                            verbose=verbose,
                            mode=mode,
                        )

                # Aggregate results.
                aggregator.aggregate(batch_outs)

                # Callbacks batch end.
                batch_logs = callbacks.make_logs(
                    model, batch_logs, batch_outs, mode
                )
                callbacks._call_batch_hook(mode, "end", step, batch_logs)
                step += 1

                if callbacks.model.stop_training:
                    break

            aggregator.finalize()
            results = aggregator.results
            epoch_logs = callbacks.make_logs(model, epoch_logs, results, mode)
            if len(results) == 1:
                results = results[0]

            # Run the test loop every epoch during training.
            if (
                do_validation
                and training_utils_v1.should_run_validation(
                    validation_freq, epoch
                )
                and not callbacks.model.stop_training
            ):
                val_results = model_iteration(
                    model,
                    validation_data,
                    steps_per_epoch=validation_steps,
                    batch_size=batch_size,
                    class_weight=class_weight,
                    workers=workers,
                    use_multiprocessing=use_multiprocessing,
                    max_queue_size=max_queue_size,
                    callbacks=callbacks,
                    verbose=verbose,
                    mode=ModeKeys.TEST,
                    steps_name="validation_steps",
                )

                if not isinstance(val_results, list):
                    val_results = [val_results]
                epoch_logs = callbacks.make_logs(
                    model, epoch_logs, val_results, mode, prefix="val_"
                )

            if mode == ModeKeys.TRAIN:
                # Epochs only apply to `fit`.
                callbacks.on_epoch_end(epoch, epoch_logs)

            # Recreate dataset iterator for the next epoch.
            if reset_dataset_after_each_epoch and epoch < epochs - 1:
                generator = tf.compat.v1.data.make_one_shot_iterator(
                    original_dataset
                )

        model._successful_loop_finish = True
        callbacks._call_end_hook(mode)
    finally:
        # Release the enqueuer and restore the learning phase on every exit
        # path, so a failure inside the loop does not leave the enqueuer's
        # workers running or the eager learning phase overridden.
        if enqueuer is not None:
            enqueuer.stop()
        if learning_phase_scope is not None:
            learning_phase_scope.__exit__(None, None, None)

    if mode == ModeKeys.TRAIN:
        return model.history
    return results


# Maintain compatibility with the existing names. Each alias is
# `model_iteration` with `mode` pre-bound. The TEST and PREDICT aliases also
# pre-bind `shuffle=False`. As with any `functools.partial`, callers may still
# override these keywords explicitly.
fit_generator = functools.partial(model_iteration, mode=ModeKeys.TRAIN)
evaluate_generator = functools.partial(
    model_iteration, shuffle=False, mode=ModeKeys.TEST
)
predict_generator = functools.partial(
    model_iteration, shuffle=False, mode=ModeKeys.PREDICT
)


def _get_next_batch(generator):
    """Retrieves the next batch of input data."""
    try:
        generator_output = next(generator)
    except (StopIteration, tf.errors.OutOfRangeError):
        return None

    if not isinstance(generator_output, tuple):
        # Always wrap in a tuple.
        generator_output = (generator_output,)
    if len(generator_output) not in [1, 2, 3]:
        raise ValueError(
            "Output of generator should be a tuple of 1 or 2 or 3 "
            "elements: (input,) or (input, target) or "
            "(input, target, sample_weights). Received {}".format(
                generator_output
            )
        )
    return generator_output


def _validate_arguments(
    is_sequence,
    is_dataset,
    use_multiprocessing,
    workers,
    steps_per_epoch,
    validation_data,
    validation_steps,
    mode,
    kwargs,
):
    """Raises errors if arguments are invalid.

    Args:
      is_sequence: Boolean, whether data is a `keras.utils.data_utils.Sequence`
        instance.
      is_dataset: Boolean, whether data is a dataset instance.
      use_multiprocessing: Boolean. If `True`, use process-based threading. If
        unspecified, `use_multiprocessing` will default to `False`. Note that
        because this implementation relies on multiprocessing, you should not
        pass non-pickleable arguments to the generator as they can't be passed
        easily to children processes.
      workers: Integer. Maximum number of processes to spin up when using
        process-based threading. If unspecified, `workers` will default to 1. If
        0, will execute the generator on the main thread.
      steps_per_epoch: Total number of steps (batches of samples) before
        declaring one epoch finished and starting the next epoch. Ignored with
        the default value of `None`.
      validation_data: Either a tuple of NumPy/Tensor inputs (i.e. `(x,)` or
        `(x, y)` or `(x, y, sample_weights)`) or a generator or
        `keras.utils.data_utils.Sequence` object or Eager Iterator or Dataset.
      validation_steps: Total number of steps (batches of samples) before
        declaring validation finished.
      mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
      kwargs: Dict of extra keyword arguments kept for backwards
        compatibility. Only `steps` is accepted.

    Raises:
      ValueError: If `steps_per_epoch` or `validation_steps` are not passed
        for data types that require them, or if unrecognized keyword
        arguments are passed.
    """
    if not is_sequence and use_multiprocessing and workers > 1:
        # Log the message text directly; wrapping it in a `UserWarning`
        # instance only to hand it to `logging.warning` adds nothing.
        logging.warning(
            "Using a generator with `use_multiprocessing=True`"
            " and multiple workers may duplicate your data."
            " Please consider using the `keras.utils.Sequence`"
            " class."
        )

    if steps_per_epoch is None and not is_dataset:
        arg_name = "steps_per_epoch" if mode == ModeKeys.TRAIN else "steps"
        raise ValueError(
            f"Please specify the number of steps via the `{arg_name}` argument."
        )

    # Generator / iterator validation data must declare how many steps to
    # draw; a `Sequence` is exempt from this check.
    val_gen = data_utils.is_generator_or_sequence(
        validation_data
    ) or isinstance(validation_data, tf.data.Iterator)
    if (
        val_gen
        and not isinstance(validation_data, data_utils.Sequence)
        and not validation_steps
    ):
        raise ValueError("Please specify the `validation_steps` argument.")

    unexpected = [k for k in kwargs if k != "steps"]
    if unexpected:
        raise ValueError(f"Invalid arguments passed: {unexpected}")


def convert_to_generator_like(
    data, batch_size=None, steps_per_epoch=None, epochs=1, shuffle=False
):
    """Wraps `data` so that batches can be pulled from it with `next()`.

    Generators, `Sequence`s and `tf.data.Iterator`s are returned as-is. A
    `tf.data.Dataset` becomes a one-shot iterator. Anything else is treated
    as a structure of NumPy arrays / EagerTensors and turned into a Python
    generator yielding batches of `batch_size` rows for `epochs` passes.

    Args:
      data: Either a generator or `keras.utils.data_utils.Sequence` object or
        `Dataset`, `Iterator`, or a {1,2,3}-tuple of NumPy arrays or
        EagerTensors. If a tuple, the elements represent
        `(x, y, sample_weights)` and may be `None` or `[None]`.
      batch_size: Rows per batch. Only used (and then required) for NumPy /
        EagerTensor input.
      steps_per_epoch: Steps of the generator to run each epoch. If `None` and
        `data` is a `keras.utils.data_utils.Sequence`, `len(data)` is used.
        For NumPy / EagerTensor input it is recomputed from the sample count
        and `batch_size`, and any passed value is ignored.
      epochs: Number of passes the array-backed generator makes over the data.
      shuffle: Whether the array-backed generator reshuffles row order at the
        start of every pass.

    Returns:
      A `(generator_like, steps_per_epoch)` pair.

    Raises:
      ValueError: If `batch_size` is not provided for NumPy or EagerTensor
        inputs.
    """
    if isinstance(data, tuple):
        # Drop slots (e.g. `targets`, `sample_weights`) that are `None` or
        # contain nothing but `None`s.
        data = tuple(
            element
            for element in data
            if any(leaf is not None for leaf in tf.nest.flatten(element))
        )

    is_iterator_like = data_utils.is_generator_or_sequence(
        data
    ) or isinstance(data, tf.data.Iterator)
    if is_iterator_like:
        if isinstance(data, data_utils.Sequence) and steps_per_epoch is None:
            steps_per_epoch = len(data)
        return data, steps_per_epoch

    if isinstance(data, tf.data.Dataset):
        return tf.compat.v1.data.make_one_shot_iterator(data), steps_per_epoch

    # NumPy / EagerTensor input: batch it ourselves.
    num_samples = int(tf.nest.flatten(data)[0].shape[0])
    if batch_size is None:
        raise ValueError(
            "When passing input data as arrays, do not specify "
            "`steps_per_epoch`/`steps` argument. "
            "Please use `batch_size` instead."
        )
    steps_per_epoch = int(math.ceil(num_samples / batch_size))

    def _batch_generator(structure):
        """Yields `structure` sliced into batches, `epochs` times over."""
        order = np.arange(num_samples)
        for _ in range(epochs):
            if shuffle:
                np.random.shuffle(order)
            for start, end in generic_utils.make_batches(num_samples, batch_size):
                flat_batch = training_utils.slice_arrays(
                    tf.nest.flatten(structure),
                    order[start:end],
                    contiguous=not shuffle,
                )
                yield tf.nest.pack_sequence_as(structure, flat_batch)

    return _batch_generator(data), steps_per_epoch


def _make_enqueued_generator(
    generator,
    workers=1,
    use_multiprocessing=False,
    max_queue_size=10,
    shuffle=False,
):
    """Optionally places `generator` behind a buffered background queue.

    Args:
      generator: A Python generator or `data_utils.Sequence`.
      workers: Number of workers for the enqueuer. When not positive, no
        enqueuer is created.
      use_multiprocessing: Forwarded to the enqueuer.
      max_queue_size: Forwarded to the enqueuer's `start()`.
      shuffle: Forwarded to `data_utils.OrderedEnqueuer` (`Sequence` input
        only).

    Returns:
      `(output_generator, enqueuer)`. With `workers > 0`, `enqueuer` is the
      started enqueuer and `output_generator` is its `get()` result; callers
      must `stop()` it when done (as `model_iteration` does). Otherwise
      `enqueuer` is `None`, and `output_generator` is
      `data_utils.iter_sequence_infinite(generator)` for a `Sequence` or
      `generator` itself.
    """
    is_sequence = isinstance(generator, data_utils.Sequence)
    run_in_background = workers > 0

    if not run_in_background:
        # No queue: batches are produced on the calling thread.
        if is_sequence:
            return data_utils.iter_sequence_infinite(generator), None
        return generator, None

    if is_sequence:
        enqueuer = data_utils.OrderedEnqueuer(
            generator,
            use_multiprocessing=use_multiprocessing,
            shuffle=shuffle,
        )
    else:
        enqueuer = data_utils.GeneratorEnqueuer(
            generator, use_multiprocessing=use_multiprocessing
        )
    enqueuer.start(workers=workers, max_queue_size=max_queue_size)
    return enqueuer.get(), enqueuer


def _make_execution_function(model, mode, class_weight=None):
    """Returns a callable that runs one batch through `model` for `mode`.

    TRAIN maps to `model.train_on_batch` with `class_weight` bound. TEST maps
    to `model.test_on_batch`. Any other mode maps to a wrapper around
    `model.predict_on_batch`. For every mode except PREDICT the result also
    has `reset_metrics=False` bound, so stateful metrics keep accumulating
    across batch-level calls.
    """
    if mode == ModeKeys.TRAIN:
        run_batch = functools.partial(
            model.train_on_batch, class_weight=class_weight
        )
    elif mode == ModeKeys.TEST:
        run_batch = model.test_on_batch
    else:
        # Accept (and ignore) `y` / `sample_weights` so that 1-, 2- and
        # 3-tuples from a generator can all be unpacked into this call.
        def predict_on_batch(x, y=None, sample_weights=None):
            return model.predict_on_batch(x)

        run_batch = predict_on_batch

    if mode == ModeKeys.PREDICT:
        return run_batch
    return functools.partial(run_batch, reset_metrics=False)


def _get_num_samples_or_steps(data, steps_per_epoch):
    """Decides whether progress is counted in samples or in steps.

    Args:
      data: The input structure handed to `model_iteration`.
      steps_per_epoch: Count to fall back on when `data` has no `shape`.

    Returns:
      `(count, use_steps)`. If the first flattened element of `data` has a
      `shape`, returns its leading dimension and `False`. Otherwise returns
      `steps_per_epoch` and `True`.
    """
    first_input = tf.nest.flatten(data)[0]
    if not hasattr(first_input, "shape"):
        return steps_per_epoch, True
    return int(first_input.shape[0]), False


class GeneratorOrSequenceTrainingLoop(training_utils_v1.TrainingLoop):
    """Training loop for Python generator or `Sequence` input.

    Each element produced by `x` already packs inputs, targets and sample
    weights together, so `x`, `y` and `sample_weight` are fused into a single
    param. `fit` and `evaluate` hand the separate `y` / `sample_weight`
    arguments to `training_utils_v1.check_generator_arguments` instead of
    forwarding them. Compare `GeneratorLikeTrainingLoop`, which receives x, y
    and sample weights as separate arguments.
    """

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
    ):
        """Trains `model` on generator / `Sequence` input via `fit_generator`.

        `batch_size` is only passed to `model._validate_or_infer_batch_size`
        and is not forwarded; batches come from `x` itself.
        """
        model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
        training_utils_v1.check_generator_arguments(
            y, sample_weight, validation_split=validation_split
        )
        enqueue_options = {
            "max_queue_size": max_queue_size,
            "workers": workers,
            "use_multiprocessing": use_multiprocessing,
        }
        return fit_generator(
            model,
            x,
            steps_name="steps_per_epoch",
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch,
            epochs=epochs,
            shuffle=shuffle,
            class_weight=class_weight,
            validation_data=validation_data,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
            callbacks=callbacks,
            verbose=verbose,
            **enqueue_options,
        )

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
    ):
        """Evaluates `model` on generator / `Sequence` input."""
        model._validate_or_infer_batch_size(batch_size, steps, x)
        training_utils_v1.check_generator_arguments(y, sample_weight)
        enqueue_options = {
            "max_queue_size": max_queue_size,
            "workers": workers,
            "use_multiprocessing": use_multiprocessing,
        }
        return evaluate_generator(
            model,
            x,
            steps=steps,
            callbacks=callbacks,
            verbose=verbose,
            **enqueue_options,
        )

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        max_queue_size=10,
        workers=1,
        use_multiprocessing=False,
    ):
        """Runs `model` inference over generator / `Sequence` input."""
        model._validate_or_infer_batch_size(batch_size, steps, x)
        enqueue_options = {
            "max_queue_size": max_queue_size,
            "workers": workers,
            "use_multiprocessing": use_multiprocessing,
        }
        return predict_generator(
            model,
            x,
            steps=steps,
            callbacks=callbacks,
            verbose=verbose,
            **enqueue_options,
        )


class EagerDatasetOrIteratorTrainingLoop(training_utils_v1.TrainingLoop):
    """Training loop for a non-distributed Dataset or iterator in eager mode.

    Every entry point delegates to the generator loop with `workers=0`, so no
    background enqueuer is used. Extra `**kwargs` are accepted and ignored.
    """

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        **kwargs,
    ):
        """Trains `model` on a Dataset / iterator via `fit_generator`.

        `y`, `sample_weight` and `validation_split` must not be supplied
        (checked by `training_utils_v1.validate_dataset_input`). When `x` is a
        Dataset and `shuffle` is truthy, the dataset is also passed to
        `training_utils_v1.verify_dataset_shuffled`.
        """
        model._validate_or_infer_batch_size(batch_size, steps_per_epoch, x)
        training_utils_v1.validate_dataset_input(
            x, y, sample_weight, validation_split
        )
        is_dataset = isinstance(
            x, (tf.compat.v1.data.Dataset, tf.data.Dataset)
        )
        if is_dataset and shuffle:
            training_utils_v1.verify_dataset_shuffled(x)

        return fit_generator(
            model,
            x,
            steps_name="steps_per_epoch",
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch,
            epochs=epochs,
            shuffle=shuffle,
            class_weight=class_weight,
            validation_data=validation_data,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
            callbacks=callbacks,
            verbose=verbose,
            workers=0,
        )

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        """Evaluates `model` on a Dataset / iterator (no `y`/`sample_weight`)."""
        model._validate_or_infer_batch_size(batch_size, steps, x)
        training_utils_v1.validate_dataset_input(x, y, sample_weight)
        return evaluate_generator(
            model,
            x,
            callbacks=callbacks,
            steps=steps,
            verbose=verbose,
            workers=0,
        )

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        """Runs `model` inference over a Dataset / iterator."""
        model._validate_or_infer_batch_size(batch_size, steps, x)
        return predict_generator(
            model,
            x,
            callbacks=callbacks,
            steps=steps,
            verbose=verbose,
            workers=0,
        )


class GeneratorLikeTrainingLoop(training_utils_v1.TrainingLoop):
    """Training loop that handles inputs like a Python generator.

    Inputs are first standardized with `model._standardize_user_data` into
    separate `x`, `y` and sample-weight structures. They are then handed to
    the generator loop as one `(x, y, sample_weights)` tuple with
    `workers=0`. This is the default handler for most input data types,
    including symbolic tensors, NumPy array-likes, and Datasets / iterators
    in graph mode (since they generate symbolic tensors). It is used to
    handle models with `run_eagerly=True`.
    """

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        **kwargs,
    ):
        """Standardizes the inputs, sets up validation, then `fit_generator`.

        `class_weight` is handed to `model._standardize_user_data` (presumably
        folded into `sample_weights` there; confirm) and is not forwarded to
        `fit_generator`.

        Raises:
          ValueError: If `validation_steps` is given without any validation
            data or a usable `validation_split`.
        """
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x
        )
        x, y, sample_weights = model._standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            check_steps=True,
            steps_name="steps_per_epoch",
            steps=steps_per_epoch,
            validation_split=validation_split,
            shuffle=shuffle,
        )

        if validation_data:
            # Explicit validation data takes precedence over the split.
            validation_data = model._prepare_validation_data(
                validation_data, batch_size, validation_steps
            )
        elif validation_split and 0.0 < validation_split < 1.0:
            # Carve the validation set off the standardized training data.
            split = training_utils_v1.split_training_and_validation_data(
                x, y, sample_weights, validation_split
            )
            x, y, sample_weights, val_x, val_y, val_sample_weights = split
            validation_data = (val_x, val_y, val_sample_weights)
        elif validation_steps:
            raise ValueError(
                "`validation_steps` should not be specified if "
                "`validation_data` is None."
            )

        return fit_generator(
            model,
            (x, y, sample_weights),
            steps_name="steps_per_epoch",
            steps_per_epoch=steps_per_epoch,
            batch_size=batch_size,
            initial_epoch=initial_epoch,
            epochs=epochs,
            shuffle=shuffle,
            validation_data=validation_data,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
            callbacks=callbacks,
            verbose=verbose,
            workers=0,
        )

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        """Standardizes the inputs and evaluates via `evaluate_generator`."""
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        x, y, sample_weights = model._standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            check_steps=True,
            steps_name="steps",
            steps=steps,
        )
        return evaluate_generator(
            model,
            (x, y, sample_weights),
            batch_size=batch_size,
            steps=steps,
            callbacks=callbacks,
            verbose=verbose,
            workers=0,
        )

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        **kwargs,
    ):
        """Standardizes `x` and runs inference via `predict_generator`."""
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        # Only the inputs matter for prediction; targets/weights are dropped.
        x, _, _ = model._standardize_user_data(
            x, check_steps=True, steps_name="steps", steps=steps
        )
        return predict_generator(
            model,
            x,
            batch_size=batch_size,
            steps=steps,
            callbacks=callbacks,
            verbose=verbose,
            workers=0,
        )

