PyTorchでDQNを実装した | moskomule log

PyTorchでDQNを実装した

これはPyTorch Advent Calendar1日目の記事です.

もしもサンタさんがこの記事を見ていたら,と考えてwish listを用意しました.サンタ各位におかれましては,よろしくお願いします.

はじめに

DQN(Deep Q Network)は Minh et al. 20151(以下論文)で登場した深層強化学習の先駆けです.Atariのゲームで非常に高い得点を修めるというパフォーマンスで有名になりました.

9月頃に強化学習の勉強をした際に実装してみたのですが,一向に学習が進まず放置していたのですが,最近Implementing the Deep Q-Network 2を読み再開してみたところ,動いてしまったので,この記事を書くことになりました.

今回の実装はこちらにあります.

強化学習とは

David Silver先生に聞きましょう.ただしこの講義では深層強化学習は扱われていません.

Deep Q-Networkとは

論文を読みましょう.Q-Learningの応用で,複雑ではありませんが,学習を安定させるための工夫が各所にあるので見逃すと動かないようです.

DQNの学習アルゴリズム.論文より.

実装について

画像の処理

DQNではAtariのゲームの画像をグレースケールにしてスタックするなどの処理がありますが,このあたりは各アルゴリズムをTensorFlowで実装し,公開しているOpen AI baselinesを一部変更して用いています.

  • OpenCV2をPillowに変更した
  • 画像のスタックの仕方をPyTorchに合わせて変更した.

また今回の改良ではtensorboard-pytorchを導入して,入力画像が正しいかを確認できるようにしました.

ネットワーク

Deepとは言えない気がしますが論文通りの構成です.何か工夫すると多少変わるのかもしれません.

class DQN(nn.Module):
    def __init__(self, output_size: int):
        super(DQN, self).__init__()
        self.feature = nn.Sequential(
                nn.Conv2d(4, 32, kernel_size=8, stride=4),
                nn.ReLU(inplace=True),
                nn.Conv2d(32, 64, kernel_size=4, stride=2),
                nn.ReLU(inplace=True),
                nn.Conv2d(64, 64, kernel_size=3, stride=1),
                nn.ReLU(inplace=True))
        self.fc = nn.Linear(64 * 7 * 7, 512)
        self.output = nn.Linear(512, output_size)

    def forward(self, x):
        x = self.feature(x)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc(x))
        x = self.output(x)
        return x

目標の行動価値函数$\hat{Q}$のネットワークを行動価値函数$Q$のネットワークの値で更新する際に

self.target_net.load_state_dict(self.net.state_dict())
for p in self.target_net.parameters():
    p.requires_grad = False

によって$\hat{Q}$の勾配を計算しないようにしています.

その他

以前の実装では見落としていていたものを挙げます.

noop max

“maximum number of ‘do nothing’ actions to be performed by the agent at the start of an episode”.エピソードの開始時に最大でnoop回何もしないことで,状態を多様化します.

update frequency

“the number of actions selected by the agent between successive SGD updates” とあって更新しないの?と思って無視していたのですが,4回毎にしか重みの更新をしないことでロバストになるのかもしれません.

learning rate

2によれば,論文の実装で使われているRMSpropにはmomentumがあるのですが,PyTorchなどのRMSprop実装にはmomentumがないのでlr=5e-5と小さめに設定します.

final exploration alternative

2に従ってε-greedy探索の探索率$\epsilon$を1から0.05(1では0.1)まで下げます.

validation用の環境の用意

2に250,000ステップ毎にvalidationをした,とあったのでvalidation環境を用意しました.この際にはnoop maxは不要です.

損失

1では2乗損失を用いていますが,3に従ってHuber損失(smoothed L1)を用います.

\[L(a)=\begin{cases} & \frac12a^2 ~~~~\text{ if } |a|\leq 1 \cr & |a| - \frac12 ~~~~\text{ otherwise} \end{cases}\]

実験

git clone https://github.com/moskomule/pytorch.rl.learning.git
cd pytorch.rl.learning.git
export PYTHONPATH=$(pwd)
cd dl/dqn
python exec.py

で卓球ゲームであるPongを学習します.--envで他のゲームを試すことも出来ます.

結果

今のところ卓球ゲームであるPongを試しています.1エピソードで20回の対戦があり勝つと+1,負けると-1の報酬を受け取ります.validation環境での報酬のグラフを見るとはじめはほぼ全敗で-20点(-21点のこともある)ですが,学習が進むにつれて安定して全勝して20点を獲得できるようになります.

Pongにおけるvalidation環境での報酬の変化.

Pongにおける損失の変化.

tensorboard-pytorchのお陰で実験の結果が簡単に可視化できるので便利です.

今後

DQNだと深層強化学習楽しい,という感じですが最近は画像処理や自然言語処理などと組み合わせた研究も多く,不可欠の技術になりつつあるのかもしれません.

一方でHenderson et al. 20174で言われるように,提案されたstate of the artの手法の再現がなかなか出来ないという問題もあります.

今後はとりあえずbaselinesにあるような代表的な深層強化学習のアルゴリズムを再実装を通して,深層強化学習に親しんでいこうと考えています.


  1. Mnih, V., Kavukcuoglu, K., Silver, D., Rusu, A. A., Veness, J., Bellemare, M. G., Hassabis, D. (2015). Human-level control through deep reinforcement learning. [return]
  2. Roderick, M., MacGlashan, J., & Tellex, S. (2017). Implementing the Deep Q-Network. [return]
  3. Szymon Sidor & John Schulman. (2017) OpenAI Baselines: DQN [return]
  4. https://arxiv.org/abs/1709.06560 [return]
comments powered by Disqus