コーディングやってる場合じゃねぇ

機械学習とコーディングと備忘録

【Python】Python+FlaskでEfficientGANによる異常検知の実装

GW中に書籍ならびに他ブログを参考にし
Python+FlaskでEfficientGANを用いた異常検知のwebアプリを作ってみた.

このブログでは自分の備忘録用であり,GANの仕組みや実装コードは省略.
コードについては参考書籍である
「つくりながら学ぶ!Pytorchによる発展ディープラーニング」をぜひ購入してください.
6-3章で紹介しており,わかりやすい.

概要

GANの一種であるEfficientGANを用いて
MNISTの7の画像を学習する.

その後,MNISTの7と8の画像を入れた際の異常度スコアを計算し,
それぞれの異常度スコアの分布から閾値を決定する.

最後にテストのデータを入れた際,
その画像が7か8かを判別する.

その際,web上で画像をアップロードできるようにFlaskを使用した.
※なお今回のはlocal hostで試行

必要データ・ディレクトリ構成

MNISTデータ

必要なデータはMNISTデータである.
私は参考書籍のサンプルファイルに記載されている
scikit-learnのデータセットを使用した.

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, data_home="./data/") 
X = mnist.data
y = mnist.target

これでXに画像データ情報が,yにクラス情報が入っている.
あとはこれを画像としてPillowで保存するだけでよい.

面倒な場合,KaggleのMNIST datasetをダウンロードすればよい.
ただし,すべての数字がごちゃ混ぜに入っているので注意.

www.kaggle.com

EfficientGAN学習用構成

続いて,EfficientGANのモデル学習のためのプログラムの
ディレクトリ構成は画像の通り. f:id:ararabo7:20200505211454p:plain プログラムファイルと同じディレクトリにdataディレクトリを置き
その下に学習データ等を置くことにした.

異常検知アプリ用の構成

最後にFlaskを用いたアプリ側の構成がこちら

f:id:ararabo7:20200505224040p:plain
Anomaly Detect App

  • templatesディレクトリ内にhtmlファイルを保存
  • staticディレクトリにアップロードした画像ファイルを保存
    ※ただし,コードの関係上dummyのファイル(MNISTの画像1枚)あり
  • modelディレクトリに学習したモデルを保存

EfficientGANを用いた異常検知

まず,EfficientGANのモデルにMNISTの7の画像を学習させる.

続いて,「つくりながら学ぶ!Pytorchによる発展ディープラーニング」内では
EffecientGANの損失関数を異常検知スコアとしている.

具体的には

  1. 入力したテスト画像とテスト画像と最もよく似た生成画像を生み出す生成乱数zを求める
  2. テスト画像と生成乱数zから生み出した生成画像のチャネル毎のピクセル毎の差を求め,
    絶対値のピクセル和を損失(residual loss)を計算する
  3. テスト画像や生成した画像ならびに使用した生成乱数zをDescriminatorに入力し,その全結合層1つ手前の
    特徴量を損失として利用する(dicrimination lossと呼んでいる)
  4. 2,3の損失の重み付平均をlossとして計算し,異常度スコアとする

となっています.
今回,この異常度スコアをMNISTデータの7,8の画像で計算し,
その分布の差から閾値を計算した.

こちらが7の異常度スコアの結果(ヒストグラム)

f:id:ararabo7:20200505221106p:plain
AnomalyScorefor7
こちらが8の異常度スコアの結果(ヒストグラム)
f:id:ararabo7:20200505221148p:plain
AnomalyScorefor8
一応,150~170あたりが7と8の境界線のように見える.

モデルの保存

学習させたモデルを保存しておかなければならない.
cloudpickleモジュールを用いて保存した.
今回はモデルが3つ(Generator, Discriminator, Encoder)があるので全て保存すること. (Generator:G_model, Discriminator: D_model, Encoder:E_model)

with open('G_model.pkl', 'wb') as f:
    cloudpickle.dump(G_model, f)
with open('D_model.pkl', 'wb') as f:
    cloudpickle.dump(D_model, f)
with open('E_model.pkl', 'wb') as f:
    cloudpickle.dump(E_model, f)

Flaskによるweb上での実装

Python用のWebフレームワークであるFlaskを用いることで簡単にwebアプリが作成できます.
参考となったブログはこちら Python × Flask × PyTorch 数字認識Webアプリのお手軽構築 - Qiita

このブログで記載されているコードを改造しました.

# Flask関連
from flask import Flask, render_template, request, redirect, url_for, abort
import numpy as np
# PyTorch関連
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
import torch.utils.data as data
import torch.optim as optim

# Pillow(PIL)、datetime
from PIL import Image, ImageOps
from datetime import datetime
class GAN_Img_Dataset(data.Dataset):
    """画像のDatasetクラス。PyTorchのDatasetクラスを継承"""

    def __init__(self, file_list, transform):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        '''画像の枚数を返す'''
        return len(self.file_list)

    def __getitem__(self, index):
        '''前処理をした画像のTensor形式のデータを取得'''

        img_path = self.file_list[index]
        img = Image.open(img_path)  # [高さ][幅]白黒

        # 画像の前処理
        img_transformed = self.transform(img)

        return img_transformed
class ImageTransform():
    """画像の前処理クラス"""

    def __init__(self, mean, std):
        self.data_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    def __call__(self, img):
        return self.data_transform(img)

# GPU or CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#Generatorのロード
import cloudpickle
with open('.model/G_model.pkl', 'rb') as f:
    G_model = cloudpickle.load(f)
#Discriminatorのロード
with open('.model/D_model.pkl', 'rb') as f:
    D_model = cloudpickle.load(f)
#Encoderのロード
with open('.model/E_model.pkl', 'rb') as f:
    E_model = cloudpickle.load(f)

app = Flask(__name__)


def Anomaly_score(x, fake_img, z_out_real, D, Lambda=0.1):
    
    # テスト画像xと生成画像fake_imgのピクセルレベルの差の絶対値を求めて、ミニバッチごとに和を求める
    residual_loss = torch.abs(x-fake_img)
    residual_loss = residual_loss.view(residual_loss.size()[0], -1)
    residual_loss = torch.sum(residual_loss, dim=1)

    # テスト画像xと生成画像fake_imgを識別器Dに入力し、特徴量マップを取り出す

    _, x_feature = D(x, z_out_real)
    _, G_feature = D(fake_img, z_out_real)

    # テスト画像xと生成画像fake_imgの特徴量の差の絶対値を求めて、ミニバッチごとに和を求める
    discrimination_loss = torch.abs(x_feature-G_feature)
    discrimination_loss = discrimination_loss.view(
        discrimination_loss.size()[0], -1)
    discrimination_loss = torch.sum(discrimination_loss, dim=1)

    # ミニバッチごとに2種類の損失を足し算する
    loss_each = (1-Lambda)*residual_loss + Lambda*discrimination_loss

    # ミニバッチ全部の損失を求める
    total_loss = torch.sum(loss_each)

    return total_loss, loss_each, residual_loss

@app.route("/", methods=["GET", "POST"])
def upload_file():
    if request.method == "GET":
        return render_template("index.html")
    if request.method == "POST":
        # アプロードされたファイルをいったん保存する
        f = request.files["file"]
        filepath = "./static/" + datetime.now().strftime("%Y%m%d%H%M%S") + ".jpg"
        f.save(filepath)
        # 保存したファイル+dummyファイルのリスト作成
        predic_img_list=list()
        predic_img_list.append(filepath)
        dummy_filepath="./static/" + "dummy" + ".jpg"
        predic_img_list.append(dummy_filepath)
        # Datasetを作成
        mean = (0.5,)
        std = (0.5,)
        test_dataset = GAN_Img_Dataset(
            file_list=predic_img_list, transform=ImageTransform(mean, std))
        # DataLoaderを作成 モデルの関係上,バッチサイズが2必要なので・・・
        batch_size = 2


        test_dataloader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False)
        # 動作の確認
        batch_iterator = iter(test_dataloader)  # イテレータに変換
        imges = next(batch_iterator)  # 1番目の要素を取り出す
        x=imges[0:2]
        x = x.to(device)

        # 教師データの画像をエンコードしてzにしてから、Gで生成
        z_out_real = E_model(imges.to(device))
        imges_reconstract = G_model(z_out_real)

        # 異常度スコアを求める
        loss, loss_each, residual_loss_each = Anomaly_score(
            x, imges_reconstract, z_out_real, D_model, Lambda=0.1)

        # 損失の計算。トータルの損失
        loss_each = loss_each.cpu().detach().numpy()
        # 1つ目のスコア(アップロードした画像の方のスコア)をresultへ
        result=np.round(loss_each[0], 0)
        if(result<175):
            strResultMsg=7
        else:
            strResultMsg=8

        return render_template("index.html", filepath=filepath, result=strResultMsg)


if __name__ == "__main__":
    app.run(debug=True)

学習時に使用したモデルのコード上,Batch Normalizationを行っており, バッチのサイズが2以上ないといけないらしく, 今回推論時にダミーで1つファイルを追加している. ※要改善点

このコードと参考ブログのhtmlファイルで

f:id:ararabo7:20200505225902p:plain
Calc Result(EfficientGAN)
のようにアップロードした画像に対して7なのか8なのかを判別することが可能に. 実際は異常検知なので7以外は全て異常として運用しないといけない

References

つくりながら学ぶ! PyTorchによる発展ディープラーニング

Python × Flask × PyTorch 数字認識Webアプリのお手軽構築 - Qiita

Houssam Zenati, Chuan Sheng Foo, Bruno Lecouat, Gaurav Manek, Vijay Ramaseshan Chandrasekhar, "Efficient GAN", https://arxiv.org/abs/1802.06222 (2017).