Git Product home page Git Product logo

ppo_test's Introduction

ppo_test

PPO implementation for InvertedPendulumEnv

Задание

С помощью алгоритма Proximal Policy Optimization для системы маятник на тележке решить задачу подъема маятника из нижнего положения в верхнее с последующей стабилизацией.

Ожидаем имплементацию всего пайплайна PPO с использованием только пакета pytorch. При этом можно ориентироваться на готовые решения.

Решение

Для реализации алгоритма PPO использовалась оригинальная статья и книга с теорией

Также ориентировалась на курс ШАДА.

Посмотрела также реализацию в Pytorch, но там странно считаются advantages - не стала на нее ориентироваться.

Достижение наилучшего результата

В процессе решения менялась функция наград в классе InvertedPendulumEnv.

Были протестированы различные функции. Их вид и результат тестирование представлены ниже.

Для того, чтобы протестировать среду с ними, необходимо вставить их в функцию:

def get_reward(ob, a):

Вариант №1

theta = np.mod(ob[1], 2*np.pi) # [0; 2*pi]
theta = (theta - 2*np.pi) if theta > np.pi else theta # [-pi; pi]
if abs(ob[0]) > 0.8:
  out_of_bound = 1
else:
  out_of_bound = 0
x = abs(ob[0]) # ob[0] in [-1.1, 1.1]
x_change_reward = -x ** 2
reward += 0.3 * x_change_reward
if abs(theta) < 0.1:
  reward += 10
else:
  reward = 2.5 * np.cos(theta) - 0.01*(ob[3])**2 - 0.1*a[0]**2 - 10*out_of_bound
var1.mp4

Вариант №2

reward = 0
theta = np.mod(ob[1], 2*np.pi) # [0; 2*pi]
theta = (theta - 2*np.pi) if theta > np.pi else theta # [-pi; pi]      
if abs(ob[0]) > 0.9:
  out_of_bound = 1
else:
  out_of_bound = 0

reward = np.cos(theta) - 0.001*(ob[3])**2 - 0.01*a[0]**2 - 1*out_of_bound
var2.mp4

Вариант №3

reward = 0
theta = np.mod(ob[1], 2*np.pi) # [0; 2*pi]
theta = (theta - 2*np.pi) if theta > np.pi else theta # [-pi; pi]

if abs(theta) > 0.9:
  reward -= 1
  coef_velocity = 1
else:
  print(theta, ob[2], ob[3])
  reward += np.exp(1 - abs(theta))
  coef_velocity = -(abs(theta) - 1)

if abs(theta) < 0.2:
  reward += 100/(0.3*ob[0]**2)

if abs(ob[0]) > 0.8:
  reward -= 10

swing_up =  1 - abs(theta) / np.pi
reward += swing_up + coef_velocity*abs(ob[3])**2 - 0.4*abs(a[0])/3 + ob[0] * 0.2*abs(ob[2]) # более плавно но перелетает все равно
var3.mp4

Вариант №4

    theta = np.mod(ob[1], 2*np.pi) # [0; 2*pi]
    theta = (theta - 2*np.pi) if theta > np.pi else theta # [-pi; pi]
    reward = -(theta**2 + 0.1*ob[3]**2 + 2*ob[0]**2)
    if abs(theta) < 0.1:
        reward += 0.1 * np.cos(theta) - 0.1*ob[3]**2
best.mp4

Наилучший результат

Наиболее похожее на ТЗ состояние:

best_new.mp4

Маятник смог подняться в вертикальное состяние и немного его удержать.

Такого состояния удалось достичь при функции наград вида:

theta = np.mod(ob[1], 2*np.pi) # [0; 2*pi]
theta = (theta - 2*np.pi) if theta > np.pi else theta # [-pi; pi]
reward = -(theta**2 + 0.1*ob[3]**2 + 2*ob[0]**2)
if abs(theta) < 0.1:
  reward += 0.1 * np.cos(theta) - 0.1*ob[3]**2

Мысли

Я попробовала разные функции наград, но честно не знаю, почему не получается удерживать маятник долгое время. То есть как забрасывать маятник понятно, что учитывать при балансе - спорный вопрос. Возможно требуется больше эпох обучения, на эксперименты особо не было времени. Может быть стоит по-другому реализовать процесс обучения или переписать PPO. Если останется время, попробую это сделать и посмотрю на результат.

В любом случае задачка была интресная, спасибо.

Evaluation

Быстро протестировать в google colab - Open In Colab

ppo_test's People

Contributors

sadevans avatar

Watchers

 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.