
meta-learning-for-everyone's Introduction

Language grade: Python License: Apache 2.0 Python 3.8.8 PyTorch 1.9.1 Code style: black Imports: isort Linting: flake8 & pylint


This repository also supports English. If you prefer English, switch to the language/english branch.

Meta-Learning for Everyone: Building Few-Shot Learning Models and Fast Reinforcement Learning Agents with PyTorch

This is the code repository for the book "Meta-Learning for Everyone" (모두를 위한 메타러닝).

Requirements

This repository uses Python 3.8 (the environment below is created with Python 3.8.8).

Installation and Usage

1. Install Anaconda

First, install Anaconda from the link below.

https://www.anaconda.com/

2. Create an Anaconda environment

Next, create a new Python environment and activate it with the following commands.

(base) $ conda create -y -n meta python=3.8.8

(base) $ conda activate meta

(meta) $ conda env list

3. Install packages

Then clone this repository and run the following command to install the required packages.

macOS and Linux users

# Users
(meta) $ make init

# Developers
(meta) $ make init-dev

Windows users

# Users
(meta) $ "./scripts/window-init.bat"

4. Train models and check results

Meta-SL

For Meta-SL, go to each algorithm's folder, run the algorithm in Jupyter Notebook, and check the results.

(meta) $ jupyter notebook

If you use Colab, follow the "Installing Torchmeta in Colab" guide to install and use Torchmeta.

Meta-RL

For Meta-RL, go to each algorithm's folder and run it with the commands below.

# RL^2
(meta) $ python rl2_trainer.py

# MAML
(meta) $ python maml_trainer.py

# PEARL
(meta) $ python pearl_trainer.py

For Meta-RL, check the training results with TensorBoard.

(meta) $ tensorboard --logdir=./results

Contributors ✨

Thanks go to these wonderful people (emoji key):


Dongmin Lee

💻 📖

Seunghyun Lee

💻 📖

Luna Jang

💻

Seungjae Ryan Lee

💻

This project follows the all-contributors specification. Contributions of any kind welcome!


meta-learning-for-everyone's Issues

About Productivity

Hi, friends. I happened to find this great repo! It seems much more elegant than most other meta-RL implementations. I wonder whether the Feature branch has been tested on environments like cheetah and can readily be used with other environments. Thanks!

`torch.lstsq()` is deprecated in Linear-feature Baseline of MAML

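torch.lstsq() is deprecated in favor of torch.linalg.lstsq() as of PyTorch 1.9.0.

To keep the code working across PyTorch versions, we could modify the function as follows. Note that the two functions take their arguments in opposite order, and torch.lstsq() returns a solution with max(m, n) rows, so only the first n rows are kept.

import torch

if hasattr(torch, "linalg") and hasattr(torch.linalg, "lstsq"):  # torch >= 1.9.0
    coeffs = torch.linalg.lstsq(A, b).solution
else:  # torch < 1.9.0: arguments are (b, A)
    coeffs = torch.lstsq(b, A).solution[: A.shape[1]]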

`argv[0]=` is printed at the beginning of training

Describe the bug

argv[0]= is printed at the beginning of training.

To Reproduce

Steps to reproduce the behavior:

  1. Run python {algorithm}_trainer.py
  2. argv[0]= is then printed in the terminal, as follows:
    Screenshot, 2021-09-23 08-11-22

Additional context

  • It may originate somewhere in the PyBullet environment.

Out of memory

Describe the bug

A "Killed" error occurs on laptops because they run out of memory.
Screenshot, 2022-01-18 02-44-04
Screenshot, 2022-01-18 02-46-51


Gather common modules

This issue is to gather the common modules (e.g., sampler, buffer), because the repo's code currently contains too much duplication.

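torchmeta import problem in Colab

Currently, torchmeta cannot be imported when running this code in Colab.

I suspected a Python version problem and ran it with python3.7,
but although the packages install successfully, the import still fails.

Has anyone solved this?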

Installing collected packages: urllib3, typing-extensions, tqdm, Pillow, ordered-set, numpy, idna, charset-normalizer, certifi, torch, requests, h5py, torchvision, torchmeta
Successfully installed Pillow-9.5.0 certifi-2022.12.7 charset-normalizer-3.1.0 h5py-3.8.0 idna-3.4 numpy-1.21.6 ordered-set-4.1.0 requests-2.30.0 torch-1.9.1 torchmeta-1.8.0 torchvision-0.10.1 tqdm-4.65.0 typing-extensions-4.5.0 urllib3-2.0.2
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv


Check whether the implemented "Higher" module works

In MAML, we use the "Higher" module, a library developed by Facebook that supports higher-order optimization.
https://github.com/facebookresearch/higher
It turns existing torch.nn.Module instances "stateless", meaning that changes to their parameters can be tracked.

Therefore, we need to check whether the implemented Higher module actually tracks the policy parameters through both the inner and outer loops.
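One way to sanity-check that the outer gradient really flows through the inner update is to compare it with a finite-difference estimate on a toy problem. The sketch below is plain Python, not Higher itself, and uses hypothetical scalar quadratic losses in place of the policy loss. With one inner step θ' = θ − α·L_in'(θ), the full MAML gradient is (1 − α·L_in''(θ))·L_out'(θ'), while a first-order approximation drops the (1 − α·L_in'') factor. If the gradient produced by the Higher-based code matches the finite-difference value on a similarly small case, the inner loop is being tracked.

```python
# Toy check: does the meta-gradient include the inner-loop (second-order) term?
# Hypothetical losses: L_in(t) = A * (t - C_IN)^2, L_out(t) = B * (t - C_OUT)^2

A, C_IN = 2.0, 1.0    # inner (support) loss parameters
B, C_OUT = 1.5, -0.5  # outer (query) loss parameters
ALPHA = 0.1           # inner-loop learning rate


def inner_grad(theta):
    return 2 * A * (theta - C_IN)


def adapt(theta):
    """One inner-loop gradient step."""
    return theta - ALPHA * inner_grad(theta)


def meta_loss(theta):
    """Outer loss evaluated at the adapted parameter."""
    return B * (adapt(theta) - C_OUT) ** 2


def full_meta_grad(theta):
    """Chain rule through the inner step: (1 - alpha * L_in'') * L_out'(theta')."""
    return (1 - ALPHA * 2 * A) * 2 * B * (adapt(theta) - C_OUT)


def first_order_meta_grad(theta):
    """First-order approximation: treats d(theta')/d(theta) as 1."""
    return 2 * B * (adapt(theta) - C_OUT)


def finite_diff(f, x, eps=1e-6):
    """Central finite-difference derivative of f at x."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


theta0 = 0.3
print(f"finite difference:   {finite_diff(meta_loss, theta0):.6f}")
print(f"full (second-order): {full_meta_grad(theta0):.6f}")
print(f"first-order only:    {first_order_meta_grad(theta0):.6f}")
```

Only the full second-order gradient agrees with the finite-difference value; if a tracked implementation reports the first-order number instead, the inner-loop dependency is being lost.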

Remove `while cur_samples < max_samples` from sampler module

Change the code from

def collect_train_data(self, task_index, max_samples, update_posterior, add_to_enc_buffer):
        """Data collecting for meta-train"""
        self.agent.encoder.clear_z()
        self.agent.policy.is_deterministic = False

        cur_samples = 0
        while cur_samples < max_samples:
            trajs, num_samples = self.sampler.obtain_samples(
                max_samples=max_samples - cur_samples,
                update_posterior=update_posterior,
                accum_context=False,
            )
            cur_samples += num_samples

            self.rl_replay_buffer.add_trajs(task_index, trajs)
            if add_to_enc_buffer:
                self.encoder_replay_buffer.add_trajs(task_index, trajs)

            if update_posterior:
                context_batch = self.sample_context([task_index])
                self.agent.encoder.infer_posterior(context_batch)

def obtain_samples(self, max_samples, update_posterior, accum_context=True):
        """Obtain samples up to the number of maximum samples"""
        trajs = []
        cur_samples = 0

        while cur_samples < max_samples:
            traj = self.rollout(accum_context=accum_context)
            trajs.append(traj)
            cur_samples += len(traj["cur_obs"])
            self.agent.encoder.sample_z()

            if update_posterior:
                break
        return trajs, cur_samples

to

def collect_train_data(self, task_index, max_samples, update_posterior, add_to_enc_buffer):
        """Data collecting for meta-train"""
        self.agent.encoder.clear_z()
        self.agent.policy.is_deterministic = False

        trajs, num_samples = self.sampler.obtain_samples(
            max_samples=max_samples,
            accum_context=False,
        )

        self.rl_replay_buffer.add_trajs(task_index, trajs)
        if add_to_enc_buffer:
            self.encoder_replay_buffer.add_trajs(task_index, trajs)

        if update_posterior:
            context_batch = self.sample_context([task_index])
            self.agent.encoder.infer_posterior(context_batch)

def obtain_samples(self, max_samples, accum_context=True):
        """Obtain samples up to the number of maximum samples"""
        trajs = []
        cur_samples = 0

        while cur_samples < max_samples:
            traj = self.rollout(accum_context=accum_context)
            trajs.append(traj)
            cur_samples += len(traj["cur_obs"])
            self.agent.encoder.sample_z()
        return trajs, cur_samples

Then, run PEARL experiments to check the change.
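A minimal, self-contained sketch of the proposed obtain_samples loop. StubSampler and episode_len are hypothetical: rollout is stubbed to return a fixed-length trajectory instead of interacting with an environment, and sample_z stands in for self.agent.encoder.sample_z(). It shows the property the change relies on: one call keeps collecting whole trajectories until at least max_samples transitions are gathered, which makes the outer while loop in collect_train_data redundant.

```python
class StubSampler:
    """Hypothetical stand-in for the repo's sampler; only the loop logic mirrors the proposal."""

    def __init__(self, episode_len):
        self.episode_len = episode_len
        self.num_z_samples = 0

    def rollout(self, accum_context=True):
        # Pretend each rollout returns one full trajectory of fixed length
        return {"cur_obs": [0.0] * self.episode_len}

    def sample_z(self):
        # Stands in for self.agent.encoder.sample_z()
        self.num_z_samples += 1

    def obtain_samples(self, max_samples, accum_context=True):
        """Collect whole trajectories until at least max_samples transitions."""
        trajs = []
        cur_samples = 0
        while cur_samples < max_samples:
            traj = self.rollout(accum_context=accum_context)
            trajs.append(traj)
            cur_samples += len(traj["cur_obs"])
            self.sample_z()
        return trajs, cur_samples


sampler = StubSampler(episode_len=200)
trajs, num_samples = sampler.obtain_samples(max_samples=1000, accum_context=False)
print(len(trajs), num_samples)
```

Because only whole trajectories are collected, num_samples can exceed max_samples when the episode length does not divide it (max_samples=900 still yields 1000 transitions here). One behavioral difference is worth checking in the PEARL experiment: in the original code, update_posterior=True made obtain_samples return after each trajectory so the posterior was inferred between trajectories, whereas the proposed version infers it once after all trajectories are collected.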

Add early stopping condition, saving, and loading

This issue tracks adding the following:

  • Early stopping condition
  • Models, buffers, and (best, final) checkpoints saving
  • Models, buffers, and (best, final) checkpoints loading
  • Train tasks and test tasks distribution (optional)

The issue is related to #22
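The early stopping condition could be as simple as a patience counter on a meta-test metric. This is a sketch under assumptions: the monitored value is an average return where higher is better, and EarlyStopping, patience, and min_delta are hypothetical names, not existing options in the repo. The is_best flag marks where a "best" checkpoint would be saved; the "final" checkpoint would be saved after the loop exits.

```python
class EarlyStopping:
    """Stop when the monitored return has not improved for `patience` checks."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.num_bad_checks = 0

    def step(self, value):
        """Record a new value; return (is_best, should_stop)."""
        if value > self.best + self.min_delta:
            self.best = value
            self.num_bad_checks = 0
            return True, False
        self.num_bad_checks += 1
        return False, self.num_bad_checks >= self.patience


stopper = EarlyStopping(patience=3)
returns = [10.0, 12.0, 11.5, 11.9, 11.0, 13.0]  # made-up per-iteration average returns
for iteration, avg_return in enumerate(returns):
    is_best, should_stop = stopper.step(avg_return)
    if is_best:
        print(f"iter {iteration}: new best {avg_return}, save 'best' checkpoint here")
    if should_stop:
        print(f"iter {iteration}: early stop")
        break
```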

put agent into sampler in MAML

Currently, get_action is misplaced in sampler.py.
To remove it, we need to check whether putting the agent into sampler.py causes any problems.

Recommended code: self.agent.policy = inner_policy
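A hypothetical sketch of the idea: the sampler holds a reference to the agent and calls agent.get_action, so switching to the adapted inner-loop policy becomes the single assignment above instead of a sampler-level get_action. Agent, Sampler, meta_policy, and inner_policy are placeholders, not the repo's classes.

```python
class Agent:
    """Placeholder agent that delegates action selection to its current policy."""

    def __init__(self, policy):
        self.policy = policy

    def get_action(self, obs):
        return self.policy(obs)


class Sampler:
    """Placeholder sampler that queries the agent instead of owning get_action."""

    def __init__(self, agent):
        self.agent = agent

    def rollout(self, observations):
        return [self.agent.get_action(obs) for obs in observations]


def meta_policy(obs):
    return 0  # stands in for the meta-learned policy


def inner_policy(obs):
    return obs * 2  # stands in for the adapted policy


agent = Agent(meta_policy)
sampler = Sampler(agent)
print(sampler.rollout([1, 2, 3]))    # actions from the meta policy
sampler.agent.policy = inner_policy  # the recommended one-line swap
print(sampler.rollout([1, 2, 3]))    # actions from the adapted policy
```

Since the sampler shares the same agent object, the swap is visible to everything else holding that agent, so the meta policy must be restored (or kept elsewhere) before the outer update; that is one of the problems the check above would need to rule out.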

Task-related configurations

  • train_tasks: the total number of tasks to train
  • num_sample_tasks: the number of sampled tasks for each iteration
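How the two settings might interact, as a sketch: each iteration draws num_sample_tasks distinct task indices out of the train_tasks available ones. The helper name sample_task_indices and the use of random.Random with a fixed seed are assumptions for illustration.

```python
import random


def sample_task_indices(train_tasks, num_sample_tasks, rng):
    """Draw distinct task indices for one meta-training iteration."""
    if num_sample_tasks > train_tasks:
        raise ValueError("num_sample_tasks cannot exceed train_tasks")
    return rng.sample(range(train_tasks), num_sample_tasks)


rng = random.Random(0)
for iteration in range(3):
    print(iteration, sample_task_indices(train_tasks=10, num_sample_tasks=4, rng=rng))
```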

Add mujoco rendering


torchmeta import error

I finished installing the packages from the Anaconda Prompt, but when I then try it in Jupyter Notebook, torchmeta fails to import.
Is there a way to fix this?
Screenshot

Modify env class attributes

Is your feature request related to a problem? Please describe.

  • In the cheetah-dir environment, self.tasks unnecessarily wraps each direction in a dictionary.
directions = [-1, 1, -1, 1]
self.tasks = [{"direction": direction} for direction in directions]

https://github.com/dongminlee94/meta-rl/blob/develop/src/envs/half_cheetah_dir.py#L19

  • self._goal is not used in either the cheetah-dir or the cheetah-vel environment.
def __init__(self, num_tasks=2, seed=0):
        super().__init__(render=False)
        self.tasks = self.sample_tasks(num_tasks)
        self._goal_vel = self.tasks[0].get("velocity", 0.0)
        self._goal = self._goal_vel
        self._task = None
        self._alive = None
        self.rewards = None
        self.potential = None
        self.seed(seed)

...

def reset_task(self, index):
        """Reset velocity target to index of task"""
        self._task = self.tasks[index]
        self._goal_vel = self._task["velocity"]
        self._goal = self._goal_vel
        self.reset()

Describe the solution you'd like

  • Change self.tasks from a list of dicts to a plain list
  • Remove the self._goal attribute
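A sketch of what the proposed change could look like, using a hypothetical, simulator-free stand-in: HalfCheetahDirTasks and _goal_dir are illustrative names, and the PyBullet base class, seeding, and reward computation are omitted. Tasks become a plain list of directions, and reset_task updates a single goal attribute with no separate self._goal.

```python
class HalfCheetahDirTasks:
    """Hypothetical, simulator-free stand-in showing only the task attributes."""

    def __init__(self, directions=(-1, 1, -1, 1)):
        # A plain list of directions instead of [{"direction": d}, ...]
        self.tasks = list(directions)
        self._goal_dir = self.tasks[0]
        self._task = None

    def reset_task(self, index):
        """Set the target direction to the task at `index` (no separate self._goal)."""
        self._task = self.tasks[index]
        self._goal_dir = self._task


env = HalfCheetahDirTasks()
env.reset_task(1)
print(env.tasks, env._goal_dir)
```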


meta-test process for MAML

We need to decide how to implement the meta-test process of the MAML algorithm.

  • Whether to iteratively perform second-order optimization
  • How do other MAML repos handle it?
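For reference on the first question: at meta-test time there is no outer (meta) update, so adaptation only needs ordinary first-order gradients of the test task's loss; second-order terms appear only when differentiating through the inner loop during meta-training. A toy sketch with a hypothetical 1-D quadratic task loss (not the RL objective) showing meta-test as a few plain inner-loop steps from the meta-learned initialization; inner_lr and num_steps are illustrative values.

```python
def task_loss_grad(theta, target):
    """Gradient of the toy task loss L(theta) = (theta - target)^2."""
    return 2 * (theta - target)


def meta_test_adapt(theta_meta, target, inner_lr=0.1, num_steps=3):
    """Adapt from the meta-learned parameter with plain gradient steps (no meta-update)."""
    theta = theta_meta
    for _ in range(num_steps):
        theta = theta - inner_lr * task_loss_grad(theta, target)
    return theta


theta_meta = 0.0  # stands in for the meta-learned initialization
for target in [1.0, -2.0]:
    adapted = meta_test_adapt(theta_meta, target)
    print(f"target {target}: adapted theta {adapted:.4f}, loss {(adapted - target) ** 2:.4f}")
```

Whether meta-test should use the same number of inner steps as meta-training, or more, is a separate design choice that this sketch leaves as a parameter.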

How to implement a value_function in MAML?

In the original MAML paper, the authors implement the value function (vf) as a linear function and refit it to each batch of data.
The Learn2Learn repo implements the vf exactly as in the paper.
However, the Ray Project repo implements the vf as a two-layer neural network that is updated simultaneously with the policy network. In this case, the vf is shared across tasks and iterations.
Screenshot, 2021-08-30 00-13-04

We need to decide how to implement this vf.
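For comparison with the neural-network option, here is a sketch of a linear feature baseline that is refit from scratch on every batch, modeled on the rllab-style linear feature baseline that many MAML implementations use. The feature set (clipped observations, their squares, and polynomial time features), reg_coeff, the path dictionary keys, and the use of NumPy are assumptions for illustration; discounted returns are assumed to be precomputed.

```python
import numpy as np


def features(obs):
    """Per-timestep features: obs, obs**2, t, t**2, t**3, bias (t scaled by 1/100)."""
    obs = np.clip(obs, -10, 10)
    t = np.arange(len(obs)).reshape(-1, 1) / 100.0
    return np.concatenate([obs, obs**2, t, t**2, t**3, np.ones((len(obs), 1))], axis=1)


def fit_linear_baseline(paths, reg_coeff=1e-5):
    """Ridge-regress discounted returns on features; called once per batch."""
    x = np.concatenate([features(p["observations"]) for p in paths])
    y = np.concatenate([p["returns"] for p in paths])
    a = x.T @ x + reg_coeff * np.eye(x.shape[1])
    return np.linalg.solve(a, x.T @ y)


def predict(coeffs, path):
    """Baseline value for every timestep of one path."""
    return features(path["observations"]) @ coeffs


# Made-up batch: 4 paths of 50 steps with 3-dimensional observations
rng = np.random.default_rng(0)
paths = [
    {"observations": rng.normal(size=(50, 3)), "returns": rng.normal(size=50)}
    for _ in range(4)
]
coeffs = fit_linear_baseline(paths)
advantages = [p["returns"] - predict(coeffs, p) for p in paths]
print(coeffs.shape, advantages[0].shape)
```

Because the coefficients are recomputed from each batch, nothing is shared across tasks or iterations, which is the key difference from a neural-network vf trained alongside the policy.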
