#add parent dir to find package. Only needed for source code build, pip install doesn't need it.
import os, inspect
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(os.path.dirname(currentdir))
os.sys.path.insert(0, parentdir)

import gym
from pybullet_envs.bullet.cartpole_bullet import CartPoleBulletEnv

from baselines import deepq


def callback(lcl, glb):
  # stop training if reward exceeds 199
  is_solved = lcl['t'] > 100 and sum(lcl['episode_rewards'][-101:-1]) / 100 >= 199
  return is_solved


def main():

  env = CartPoleBulletEnv(renders=False)
  model = deepq.models.mlp([64])
  act = deepq.learn(env,
                    q_func=model,
                    lr=1e-3,
                    max_timesteps=100000,
                    buffer_size=50000,
                    exploration_fraction=0.1,
                    exploration_final_eps=0.02,
                    print_freq=10,
                    callback=callback)
  print("Saving model to cartpole_model.pkl")
  act.save("cartpole_model.pkl")


if __name__ == '__main__':
  main()
