PyTorchでRNN入門 | moskomule log

PyTorchでRNN入門

RNNの概説

RNNは再帰型ニューラルネットワーク(recurrent neural network)の略です.各層は以前の自分自身の出力も入力とする再帰的な構造をもつため,この名がつけられています.時間依存のある文や信号といった入力を処理することができます.

RNN(左)とCNNなどのネットワーク(右).RNNは自分自身の出力も入力として取り込むことで,時間に依存した情報を扱うことができると考えられた.

RNN自体は90年代初頭にJ.Elmanらによって提案され1,文生成や分散表現の獲得などの研究が行われています.現在でもよく使われる,より長い系列にも対応できるLSTMも90年代末に提案されており,伝統のあるネットワークであるといえるでしょう.

画像におけるCNNの華々しい活躍と比較すると劣りますが,それでもGoogle翻訳の昨今の「自然な」翻訳の背景にはRNNがあります.

単純なRNN

入力とする系列$x_0,x_1,\cdots,x_T$を$x_0$から順次与えていきます.ここでは下付き文字$\star_t$は時刻を表します.また上付き文字$\star^l$を$l$層目の状態として,時刻$t$における$l$層目の隠れ状態を$h_t^l$,出力を$y_t$と表します.

再帰型ではないニューラルネットワークでは,ある層$l$の隠れ状態$h^l$は,その前の層への状態に重み$W^l$をかけ,活性化函数$f$に与えたもので,

\[h^l = f(W^lh^{l-1})\]

でした(ただし簡便のためにバイアスは省きました.今後も同様です.).

一方で,RNNには時刻の概念があり,さらに一つ前の状態を考慮するため,ある時刻$t$における,層$l$の状態$h^l_t$は

\[h^l_t = f(W^lh_{t}^{l-1}+U^lh_{t-1}^l)\]

です.つまり,前の層の出力$W^lh_t^{l-1}$に,前時刻の自分の出力$U^lh_{t-1}^l$が加わったものを活性化函数に与えることとなります.活性化函数$f$としては$\tanh,\mathrm{relu},\mathrm{sigmoid}$などが用いられます.

それでは実際に系列を入力してみましょう.まず,$x_0$を入力します.

時刻 $\sim 0$

このとき$t=0$では,上の式から

\[h^1_0=f(W^1x_0+U^1h_{-1}^1),h^2_0=f(W^2h^1_0+U^2h_{-1}^2)\]

となります.この$h_{-1}^1,h_{-1}^2$は最初は隠れ状態がないために与える必要がある「仮の隠れ状態」で,$0$など適当に初期化されたベクトルを用います.同様にして,recurrent層が$L$層あれば時刻0において$x_0$と$h_{-1}^1,h_{-1}^2,\cdots,h_{-1}^L$を用意する必要があります.また,最終層は

\[y_t=f_y(W^{L}h_t^{L})\]

で与えられます.

その後は隠れ状態があるので,順次

\[h^1_1=f(W^1x_1+U^lh_{0}^1)\]

などとなります.

時刻 $0\sim 1$

時刻 $1\sim 2$

時刻 $2\sim $

重みの更新は一つの系列が終了してから行います.このとき用いる損失は,目標を$d_0,d_1,\cdots,d_T$として,すべての時刻に対して出力が必要な場合,例えば文章生成の場合,

\[\sum_t\mathrm{loss}(y_t,d_t)\]

とします.または,二値分類などでは$y_T$には$y_0,\cdots,y_{T-1}$の情報が蓄積されていると考えて

\[\mathrm{loss}(y_T, d_T)\]

を用います.

いずれにしても,このとき$t=T$での損失から$t=0$での隠れ状態も考慮することとなります.上の図では$t=2$までしかありませんが,$h_0^1$から$y_2$までの経路は,例えば$h_0^1\to h_1^1\to h_1^2\to h_2^2\to y_2$などのように,一般のネットワークでは隠れ層4のネットワークに相当します.

そのため,理論的には長い系列を処理することができますが,実際にはこのような単純なRNNでは容易に勾配消失がおこり,長い系列は学習できなくなることが知られています.

PyTorchではこの単純なRNNはnn.RNNに用意されています(後述).

Gated RNN

CNNでは勾配消失を防ぎつつより深く層を重ねるためにResNetなどが発表されています.

ResNetの一部の概略図

これらは簡略化すれば

\[h^l=f(h^{l-1})+h^{l-1}\]

と表されます.これによって第1項の勾配が0になるような場合であっても,全体の勾配が消失しないことが期待されます.

これを一般化して,要素積$\odot$を使い

\[h^l = g\odot f(h^{l-1})+i\odot h^{l-1}\]

を考えます.上は$g=i=1$の特殊な場合です.ただし,RNNの場合問題となるのは時間方向ですから,以下では

\[h_t = g_{t}\odot f(h_{t-1})+i_{t}\odot h_{t-1}\]

を考えます.これは再帰的に

\[g_t\odot f(h_{t-1})+\sum_{s=1}^{t-1}(\bigodot_{\tau=s+1}^t i_{\tau})\odot g_{s}\odot f(h_{s-1})+(\bigodot_{\tau=1}^t i_{\tau})\odot h_0\]

となります.ただし$\bigodot_{t=0}^T x_t=x_0\odot x_1\odot\cdots\odot x_T$です.

$i_{\tau}\in [0, 1]$とすれば$\bigodot_{\tau=t}^t i_{\tau} \leq \bigodot_{\tau=t+1}^t i_{\tau}$ですから,時間が経ったことほど忘れやすくなります.この$i$も$h_t$の函数として学習し忘却の度合いも調節することで,短期の記憶だけでなく長期の記憶を適度に活かすことが期待されます.

以上の説明は2を参考にしました.

GRU

GRU(Gated Recurrent Unit)はCho et al.(2014)3で提案された手法です.後発にもかかわらず,LSTMと較べてセルが少なくシンプルな構造ですが,特に性能に大きな差はないとされています.上のgated RNNの式にも近く理解しやすいと思います.

\[\begin{aligned} r_t^l &= \mathrm{sigmoid}(W_r^{l}h_t^{l-1}+U_r^lh_{t-1}^l) \cr\cr z_t^l &= \mathrm{sigmoid}(W_z^{l}h_t^{l-1}+U_z^lh_{t-1}^l) \cr\cr n_t^l &= \tanh(W_n^{l}h_t^{l-1}+r_t^l\odot U_n^lh_{t-1}^l) \cr\cr h_t^l &= (1-z_t^l)\odot n_t^l + z_t^l\odot h_{t-1}^l \end{aligned}\]

$r_t,z_t,n_t$は文献によって異なりますが,Cho et al.(2014)ではそれぞれreset gate, update gateおよびcandidate gateと呼ばれています.下に4のGRUのイラストを示しました.sigmoid函数によって0から1の間の値を取るreset gate $r_t$,やupdate gate $z_t$にはスイッチのような役割があることがわかります.

GRUのイラスト.$z,r$はそれぞれupdate gate, reset gate,$\tilde{h}$は我々の記法では$n$のcandidate gateのこと.

LSTM

LSTM(Long Short-Term Memory, 5)は1997年にS.Hochreiterらによって提案されたGated RNN手法の一つで,現在でもよく使われています.

\[\begin{aligned} i_t^l &=\mathrm{sigmoid}(W_i^lh_t^{l-1}+U_i^lh_l^{t-1}) \cr\cr f_t^l &=\mathrm{sigmoid}(W_f^lh_t^{l-1}+U_f^lh_l^{t-1}) \cr\cr g_t^l &=\tanh(W_g^lh_t^{l-1}+U_g^lh_l^{t-1}) \cr\cr o_t^l &=\mathrm{sigmoid}(W_o^lh_t^{l-1}+U_o^lh_l^{t-1}) \cr\cr c_t^l &= f_t^l\odot c_{t-1}^{l}+z_t^l\odot g_t^l \cr\cr h_t^l &= o_t^l\odot\tanh(c_t^l) \end{aligned}\]

GRUと較べるとかなり複雑で,隠れ状態として$h_t$だけでなく,memory cell $c_t$を用います.そのため,はじめに$(h_{-1}, c_{-1})$を用意する必要があります.

$i_t,f_t,g_t,o_t$はそれぞれinput gate, forget gate, cell gateおよびouput gateなどと呼ばれています.

LSTMのイラスト.$i,o,f$はそれぞれinput gate, output gate, forget gates,$\tilde{c}$は我々の記法では$g$のgate gateのこと.

PyTorchにおけるRNN

PyTorchにおけるRNNレイヤーとしてはnn.RNN,nn.LSTM,nn.GRU,Cellとしてはnn.RNNCell,nn.LSTMCell,nn.GRUCellがあります.以下ではそれぞれの使い方を解説します.

課題の解説

今回は3桁の数2つを文字列として受け取り,その数の和を予測するネットワークを例として説明を行います.たとえば入力として”123+234”を受け取った場合” 357”を予想します.簡単のため,ネットワークの入出力の長さは固定します.数字が足りない箇所にはスペース” “を入れることで,入力は7文字,出力は4文字とします.

今回,各数字,”1”や”8”は記号としての意味が重要で,その大小は肝心ではありません.そのため,ネットワークにそのまま数値を入れることは望ましくありません.また,文字を用いたタスクの場合はそもそも数字として表すことができません.そのため,RNNの入力としては入力の値を何らかの規則でベクトルとして表したものを用いることが一般的です.

今回は0から9までの数字と”+“,” “の12種類の記号のみを扱うためonehotと呼ばれる表現を用います.これは記号と対応する位置の要素だけが1,残りが0のようなベクトルで”+“を2番目の要素とすると”010000000000”が”+“に対応するonehot表現です.onehot化した足し算の問題を入力として,onehotの4ベクトルを出力として得て,正解とのcross entropy lossを最小化していきます.つまり

loss = sum(F.cross_entropy(o, t) for o, t in zip(output, target))

です.なお,PyTorchでは目標targetはonehot表現ではなく,インデックスで与えます.

なお簡単のため以下の例ではRNNは1層,全結合層も1層とします.

nn.RNNCell,nn.LSTMCell,nn.GRUCell

Cellは以下の図のcellに相当し,ある時刻での隠れ状態と入力とを受け取り,隠れ状態を返します.

Cellを用いたネットワークの説明

nn.**Cellのパラメーターinput_sizeは入力ベクトルの次元数で,今回はonehotベクトルなので12です.

class Cell(RNNBase):
    def __init__(self, rnn_name, char_num, batch_size, output_size):
        super(Cell, self).__init__(rnn_name, char_num, batch_size, output_size, num_layers=1)

        self.rnn = getattr(nn, self.rnn_name+"Cell")(
            input_size=char_num,
            hidden_size=self.hidden_size)
        for i in range(output_size):  # 出力文字数分の隠れ層を用意します
            setattr(self, f"fc_{i}", nn.Linear(self.hidden_size, char_num))

    def forward(self, x):
        h = self.get_hidden() # 隠れ状態を初期化
        for input in x:  # 入力を一文字ずつ取得します
            h = self.rnn(input, h) # 隠れ状態を更新し,受け継いでいきます

        if self.rnn_name is "LSTM": # LSTMはh,cと2つの隠れ状態を持ちます.ここではhのみを使います
            h = h[0]
        h = F.relu(h)
        output = []
        for i in range(self.output_size):
            output += getattr(self, f"fc_{i}")(h)  # 各隠れ層に隠れ状態を入力として与えます
        return output

nn.RNN,nn.LSTM,nn.GRU

こちらはレイヤーのように扱い,以下の図の緑の長方形に囲まれた部分が相当します.初期の隠れ状態と入力系列全体(図の黄色の長方形)を入力として受け取ります.入力系列は(入力系列の長さ,バッチサイズ,入力ベクトルの次元数)です.パラメーターnum_layersによってRNNを何層重ねるかを決めることができます.

緑の長方形で囲まれた部分がレイヤー.

出力は全時刻における最終層出力outputと,系列の最後の時刻における各層の隠れ状態h_nを返します.今回は最後の時刻の最終層の隠れ状態のみを必要とするので,output[-1]h_n[-1]のどちらも使うことが出来ますが,今回はh_nを用いています.

class RNN(RNNBase):
    def __init__(self, rnn_name, char_num, batch_size, output_size, num_layers):
        super(RNN, self).__init__(rnn_name, char_num, batch_size, output_size, num_layers)

        self.rnn = getattr(nn, self.rnn_name)(
            input_size=char_num,
            hidden_size=self.hidden_size)
        for i in range(output_size):
            setattr(self, f"fc_{i}", nn.Linear(self.hidden_size, char_num))

    def forward(self, x):
        h_0 = self.get_hidden()
        output, h = self.rnn(x, h_0)  
        # h_0を与えない場合,0ベクトルが与えられます.つまり,
        # output, h = self.rnn(x) も動きます.
        if self.rnn_name is "LSTM": # LSTMはh,cと2つの隠れ状態を持ちます.ここではhのみを使います
            h = h[0]
        h = F.relu(h[-1]) # h は(num_layers, batch_size, hidden_size)なので[-1]で最終層の隠れ状態を得ます
        return [getattr(self, f"fc_{i}")(h) for i in range(self.output_size)]

このように,PyTorchではcellとlayerを用いることで柔軟にRNNを用いたネットワークをつくることができます.

今回説明に用いたコードの全体はこちらにあります.54000組の学習によって95%程度の正答率を得ることができました.

なお,今回正答率を求めるに当たって以下のようなコードを用いましたが,はじめtorch.max()の返すインデックスがByteTensorのために正解数をTensorのまま扱うと警告なしにオーバーフローしてしまい,lossは下がれど正答率は上がらず,ということがありました.LongTensorにキャストするか,以下のようにsum()によってintに変換すると良いと思います.

tmp = sum([F.log_softmax(o).data.max(1)[1] == t.data for o, t in zip(output, target)])
count += (tmp == DIGITS+1).sum()

  1. Elman, J. L. (1991). Distributed Representations, Simple Recurrent Networks, And Grammatical Structure. Machine Learning, 7(2), 195–225. http://doi.org/10.1023/A:1022699029236 [return]
  2. 深井裕太・海野裕也・鈴木潤 (2017).『深層学習による自然言語処理』.講談社. [return]
  3. Cho, K., van Merrienboer, B., Bahdanau, D., & Bengio, Y. (2014). On the Properties of Neural Machine Translation: Encoder-Decoder Approaches. http://doi.org/10.3115/v1/W14-4012 [return]
  4. Chung, J., Gulcehre, C., Cho, K., & Bengio, Y. (2014). Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling, 1–9. Retrieved from http://arxiv.org/abs/1412.3555 [return]
  5. Hochreiter, S., & Urgen Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8), 1735–1780. http://doi.org/10.1162/neco.1997.9.8.1735 [return]
comments powered by Disqus