PyTorchでCNN入門

Jun 10, 2017   #PyTorch  #Python 

CNNの概説

CNNは畳み込みニューラルネットワーク(convolutional neural network)の略です.CNNは四天王のひとりLeCun(1989)に始まり,2012年の一般物体認識のコンテスト(ILSVRC)で優勝しディープラーニングを一躍有名にしたAlexNet(Krizhevsky)を経て,現在の画像認識には欠かせないネットワークです.

畳み込み

CNNでは畳み込み(convolution)という操作を行います.ここでは簡単のためにすべて2次元で考えます.

以下のように入力の行列とフィルタが与えられたときに,

\[I=ar+bs+tc+du+ev+fw+gz+hy+iz\]

を畳み込みと呼びます.本来画像認識の分野では$ax+by+cz+\cdots$を畳み込みと呼び,上記の演算は相関と呼ばれるようですが,CNNの文脈ではこれを畳み込みと呼ぶよう1なので慣例に倣います.

左が入力の一部,右がフィルター.

入力に対して,この操作を同じフィルタをずらしながら適用していきます.下の図では上部の入力とフィルタの畳み込み結果を下の出力行列の各要素にする様子を書きました.こうして畳み込みによる出力が得られます.

上が入力とフィルタ,下が出力.

畳み込みは画像の対応部分とフィルタとの内積を取ることですから,それらの関連ぐらいを見ていることになります(それ故に相関と呼ばれるのですが).従って,出力は入力画像のフィルタとの関連度を凝縮したものになるわけです.以上の畳み込み(あるいは相関)自体はCNN以前から画像認識の分野で用いられてきましたが,CNNではフィルタ自体を誤差逆伝播法で学習していく点が従来とは異なります2

上では入力,フィルタとも1枚ずつである場合を考えましたが,一般にそれらは複数枚あり,テンソルとして扱われます.この「枚数方向」の次元はチャネルと呼ばれます.特にRGB画像は3チャネルです.

複数チャネルの場合は,入力の各チャネルに対して同一のフィルタを適用し,その和をとります.従って,出力のチャネル数はフィルタ数と一致します.

プーリング

CNNではその他にプーリングという操作を行う場合もあります.その中でもよく用いられる最大プーリング(max pooling)は下に示したように,領域内の最大値を取り出して出力とする操作です.画像認識では位置がずれた同じ物体も同じものとして認識したいので,この操作を加えて位置に対する不変性を向上させます.

最大プーリング.上部が入力で下部が出力.

最大プーリングのほかに,平均値を用いるプーリングもあります.

用語

説明に用いる画像はこちらのもので,今までと異なり下が入力,上が出力です.

kernel

上記の畳み込みのフィルタやプーリングの領域のことをカーネルと呼ぶこともあります.

stride

カーネルの動く際のステップです.プーリングの場合は領域幅と同じ幅で動かし,重複する範囲がないようにすることが多い気がします.

stride=1

padding

畳み込み,プーリングを上記のように行った場合,出力は入力よりも小さくなります.入力の周りに「枠」を付けることで出力サイズを調整するのがpaddingです.「枠」を0で埋めるゼロパディングがしばしば用いられます.

stride=1,padding=1

dilation

カーネルにあける隙間の大きさです.プーリングの代わりにdilationを用いることもあるようです.

stride=1,dilation=1

relu

活性化函数の一つで,“rectified linear unit”の略です.函数としては

\[\mathrm{relu}(x)=\max(0,x)\]

と極めて単純ですが,これがなければ現在のディープニューラルネットワーク時代はなかった,とも言えるような,強力な存在です.以前はsigmoid函数(S字状函数),たとえば

\[\mathrm{sigmoid}(x)=\frac{1}{1+e^{-x}}\]

が用いられていましたが,ネットワークが深くなると勾配が消失する問題を抱えていました.

sigmoid函数とrelu函数との比較

PyTorchにおけるCNN

PyTorchの簡単なチュートリアルはこちらにあります.

コード中のFnn.functionalのことです.

nn

nn.Conv2dを用います.F.conv2dというものもありますが,こちらは自分で明示的にweight,biasのテンソルを用意し,必要であれば重みを更新しなくてはいけません.他方,nn.Conv2dであれば入出力のチャネル数およびカーネルの大きさを定めるだけです.今回は用いていませんが,上で説明したpadding,dilationを用いることもできます.

プーリングには,畳み込みのように更新すべきテンソルがないのでF.max_pool2dを用いても差はありません.

class Net1(nn.Module):
    def __init__(self):
        super(Net1, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1,
                               out_channels=10,
                               kernel_size=5,
                               stride=1)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.dense1 = nn.Linear(in_features=320,
                                out_features=50)
        self.dense2 = nn.Linear(50, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.max_pool2d(x, kernel_size=2)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(x)
        x = x.view(-1, 320)
        x = self.dense1(x)
        x = F.relu(x)
        x = self.dense2(x)
        return F.log_softmax(x)

一方で,モデルの一部をほかのモデルに流用したい場合などは以下のように書くと便利かもしれません.

# alternative way
class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.head = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=10,
                      kernel_size=5, stride=1),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Conv2d(10, 20, kernel_size=5),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU())
        self.tail = nn.Sequential(
            nn.Linear(320, 50),
            nn.ReLU(),
            nn.Linear(50, 10))

    def forward(self, x):
        x = self.head(x)
        x = x.view(-1, 320)
        x = self.tail(x)
        return F.log_softmax(x)

こうすれば,この分類のモデルによってアルファベットの分類を行いたい場合,

class Net3(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        net2 = Net2()
        self.head = net2.head
        self.tail = nn.Sequential(
            nn.Linear(320, 50),
            nn.ReLU(),
            nn.Linear(50, 52))    

と表せるので簡便です.

その他

強力なモデルを使う

PyTorchの便利ツールtorchvisionには強力なモデルであるAlexNet,VGG,ResNet,SqueezeNet,DenseNetが用意されています.特にAlexNet,VGGおよびResNetはImageNet2012での学習済みの重みをダウンロードすることができます.

from torchvision import models
alexnet = models.AlexNet(pretrained=True)
output = alexnet(input)

また,この重みを初期値として,少ないデータ量でも学習を行うことがあります(fine-tuning).その場合は先ほど紹介したような方法で,classifier部をつなぎ替えることができます.

class Net3(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.head = alexnet.features
        self.tail = nn.Sequential(
            ...)    

重みの初期化

CNNにおいては重みの初期値が重要でさまざまな初期値が提案されています[Glorot et al. 2010, He et al. 2015].PyTorchではnn.initにさまざまな初期化方法が収録されています.Xavier initializationであれば以下によって初期化できます.

from torch.nn import init
 torch.nn.init.xavier_normal(alexnet.features[0].weight)

重みの保存,読み込み

保存は

torch.save({"state_dict": model.state_dict(),
            "epoch":...},
            filename)

です.読み込む際には

model = Model()
weights = torch.load(filename)
model.load_state_dict(weights["state_dict"])

です.一部のみを読み込む場合,例えばresnetの最終出力層resnet.fcの重みをFully Convolutionのカーネルfcn.convに入れたい場合は

res_weight = resnet.fc.state_dict()
fcn.conv.load_state_dict({"weight": res_weight["weight"].view(1000, 2048, 1, 1),
                          "bias": res_weight["bias"]})

となります.


  1. 岡谷「深層学習」2015,原田「画像認識」2017 [return]
  2. 筆者は深層学習以前を知らないので実はよく分かりません. [return]