import pybullet_utils.mpi_util as MPIUtil
"""

Some simple logging functionality, inspired by rllab's logging.
Assumes that each diagnostic gets logged each iteration

Call configure_output_file() on a Logger instance to start logging to a
text file (some_file_name.txt) whose columns are left-aligned and
space-padded to 25 characters

To load the learning curves, you can do, for example

A = np.genfromtxt('output/log_1468984536.txt', dtype=None, names=True)
A['EpRewMean']

"""

import os.path as osp, shutil, time, atexit, os, subprocess


class Logger:
  """Tabular diagnostics logger that prints and writes only on the root MPI process.

  Typical use per iteration: call log_tabular() once per diagnostic, then
  print_tabular() and/or dump_tabular(). The set of keys is fixed by the
  first row; later rows may only reuse those keys.
  """

  @staticmethod
  def print2(str):
    """Print `str` on the root process only.

    The parameter name shadows the builtin but is kept so existing keyword
    callers continue to work.
    """
    if MPIUtil.is_root_proc():
      print(str)

  def __init__(self):
    self.output_file = None  # open file handle; only ever set on the root process
    self.first_row = True  # True until the first dump_tabular() call
    self.log_headers = []  # diagnostic keys, in first-seen order
    self.log_current_row = {}  # key -> value for the row being built
    self._dump_str_template = ""  # format string with one column per header

  def _clear_rows(self):
    """Forget all headers and pending values so the next row is a first row."""
    self.first_row = True
    self.log_headers = []
    self.log_current_row = {}

  def _open_output_file(self, path):
    """Close any currently open log file, then create/truncate `path` for writing."""
    if self.output_file is not None:
      self.output_file.close()
    self.output_file = open(path, 'w')
    # Closing an already-closed file is a no-op, so close handlers left over
    # from previously opened files are harmless.
    atexit.register(self.output_file.close)

  def reset(self):
    """Clear all logged state and, if a log file is open, truncate it.

    The file is re-opened under its existing name so that a fresh header
    row is written by the next dump_tabular() call.
    """
    self._clear_rows()
    if self.output_file is not None:
      self._open_output_file(self.output_file.name)

  def configure_output_file(self, filename=None):
    """Start logging to `filename`, creating its parent directory if needed.

    Args:
      filename: path of the log file. When None (or empty), defaults to
        "output/log_<unix-timestamp>.txt".

    Only the root process creates the directory and opens the file; on other
    processes this just clears the logged state.
    """
    self._clear_rows()

    output_path = filename or "output/log_%i.txt" % int(time.time())

    if MPIUtil.is_root_proc():
      out_dir = os.path.dirname(output_path)
      # dirname is '' for a bare file name; os.makedirs('') would raise.
      if out_dir and not os.path.isdir(out_dir):
        os.makedirs(out_dir)

      self._open_output_file(output_path)
      Logger.print2("Logging data to " + self.output_file.name)

  def log_tabular(self, key, val):
    """Record the value of one diagnostic for the current row.

    Call this once for each diagnostic quantity, each iteration.

    Raises:
      AssertionError: if `key` is new and the first row has already been
        dumped.
    """
    if key not in self.log_headers:
      assert self.first_row, "Trying to introduce a new key %s that you didn't include in the first iteration" % key
      self.log_headers.append(key)
    self.log_current_row[key] = val

  def get_num_keys(self):
    """Return the number of distinct diagnostic keys registered so far."""
    return len(self.log_headers)

  def print_tabular(self):
    """Print the current row as a two-column table (root process only).

    Keys with no value logged for the current row are shown as empty.
    """
    if not MPIUtil.is_root_proc():
      return

    Logger.print2("-" * 37)
    for key in self.log_headers:
      val = self.log_current_row.get(key, "")
      if isinstance(val, float):
        valstr = "%8.3g" % val
      elif isinstance(val, int):
        valstr = str(val)
      else:
        valstr = val
      Logger.print2("| %15s | %15s |" % (key, valstr))
    Logger.print2("-" * 37)

  def dump_tabular(self):
    """Write the current row to the log file, then clear it.

    On the root process the header line is written before the first row.
    On every process the pending row is cleared and the key set is frozen.
    """
    if MPIUtil.is_root_proc():
      if self.first_row:
        self._dump_str_template = self._build_str_template()

      if self.output_file is not None:
        if self.first_row:
          header_str = self._dump_str_template.format(*self.log_headers)
          self.output_file.write(header_str + "\n")

        vals = [str(self.log_current_row.get(key, "")) for key in self.log_headers]
        self.output_file.write(self._dump_str_template.format(*vals) + "\n")
        self.output_file.flush()

    self.log_current_row.clear()
    self.first_row = False

  def _build_str_template(self):
    """Return a format string with one left-aligned, 25-wide field per key."""
    return "{:<25}" * self.get_num_keys()
