【Python】Python+FlaskでEfficientGANによる異常検知の実装
GW中に書籍ならびに他ブログを参考にし
Python+FlaskでEfficientGANを用いた異常検知のwebアプリを作ってみた.
このブログでは自分の備忘録用であり,GANの仕組みや実装コードは省略.
コードについては参考書籍である
「つくりながら学ぶ!Pytorchによる発展ディープラーニング」をぜひ購入してください.
6-3章で紹介しており,わかりやすい.
概要
GANの一種であるEfficientGANを用いて
MNISTの7の画像を学習する.
その後,MNISTの7と8の画像を入れた際の異常度スコアを計算し,
それぞれの異常度スコアの分布から閾値を決定する.
最後にテストのデータを入れた際,
その画像が7か8かを判別する.
その際,web上で画像をアップロードできるようにFlaskを使用した.
※なお今回のはlocal hostで試行
必要データ・ディレクトリ構成
MNISTデータ
必要なデータはMNISTデータである.
私は参考書籍のサンプルファイルに記載されている
scikit-learnのデータセットを使用した.
from sklearn.datasets import fetch_openml mnist = fetch_openml('mnist_784', version=1, data_home="./data/") X = mnist.data y = mnist.target
これでXに画像データ情報が,yにクラス情報が入っている.
あとはこれを画像としてPillowで保存するだけでよい.
面倒な場合,KaggleのMNIST datasetをダウンロードすればよい.
ただし,すべての数字がごちゃ混ぜに入っているので注意.
EfficientGAN学習用構成
続いて,EfficientGANのモデル学習のためのプログラムの
ディレクトリ構成は画像の通り.
プログラムファイルと同じディレクトリにdataディレクトリを置き
その下に学習データ等を置くことにした.
異常検知アプリ用の構成
最後にFlaskを用いたアプリ側の構成がこちら
- templatesディレクトリ内にhtmlファイルを保存
- staticディレクトリにアップロードした画像ファイルを保存
※ただし,コードの関係上dummyのファイル(MNISTの画像1枚)あり - modelディレクトリに学習したモデルを保存
EfficientGANを用いた異常検知
まず,EfficientGANのモデルにMNISTの7の画像を学習させる.
続いて,「つくりながら学ぶ!Pytorchによる発展ディープラーニング」内では
EffecientGANの損失関数を異常検知スコアとしている.
具体的には
- 入力したテスト画像とテスト画像と最もよく似た生成画像を生み出す生成乱数zを求める
- テスト画像と生成乱数zから生み出した生成画像のチャネル毎のピクセル毎の差を求め,
絶対値のピクセル和を損失(residual loss)を計算する - テスト画像や生成した画像ならびに使用した生成乱数zをDescriminatorに入力し,その全結合層1つ手前の
特徴量を損失として利用する(dicrimination lossと呼んでいる) - 2,3の損失の重み付平均をlossとして計算し,異常度スコアとする
となっています.
今回,この異常度スコアをMNISTデータの7,8の画像で計算し,
その分布の差から閾値を計算した.
こちらが7の異常度スコアの結果(ヒストグラム) こちらが8の異常度スコアの結果(ヒストグラム) 一応,150~170あたりが7と8の境界線のように見える.
モデルの保存
学習させたモデルを保存しておかなければならない.
cloudpickleモジュールを用いて保存した.
今回はモデルが3つ(Generator, Discriminator, Encoder)があるので全て保存すること.
(Generator:G_model, Discriminator: D_model, Encoder:E_model)
with open('G_model.pkl', 'wb') as f: cloudpickle.dump(G_model, f) with open('D_model.pkl', 'wb') as f: cloudpickle.dump(D_model, f) with open('E_model.pkl', 'wb') as f: cloudpickle.dump(E_model, f)
Flaskによるweb上での実装
Python用のWebフレームワークであるFlaskを用いることで簡単にwebアプリが作成できます.
参考となったブログはこちら
Python × Flask × PyTorch 数字認識Webアプリのお手軽構築 - Qiita
このブログで記載されているコードを改造しました.
# Flask関連 from flask import Flask, render_template, request, redirect, url_for, abort import numpy as np # PyTorch関連 import torch import torch.nn as nn import torch.nn.functional as F import torchvision from torchvision import datasets, transforms import torch.utils.data as data import torch.optim as optim # Pillow(PIL)、datetime from PIL import Image, ImageOps from datetime import datetime class GAN_Img_Dataset(data.Dataset): """画像のDatasetクラス。PyTorchのDatasetクラスを継承""" def __init__(self, file_list, transform): self.file_list = file_list self.transform = transform def __len__(self): '''画像の枚数を返す''' return len(self.file_list) def __getitem__(self, index): '''前処理をした画像のTensor形式のデータを取得''' img_path = self.file_list[index] img = Image.open(img_path) # [高さ][幅]白黒 # 画像の前処理 img_transformed = self.transform(img) return img_transformed class ImageTransform(): """画像の前処理クラス""" def __init__(self, mean, std): self.data_transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean, std) ]) def __call__(self, img): return self.data_transform(img) # GPU or CPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #Generatorのロード import cloudpickle with open('.model/G_model.pkl', 'rb') as f: G_model = cloudpickle.load(f) #Discriminatorのロード with open('.model/D_model.pkl', 'rb') as f: D_model = cloudpickle.load(f) #Encoderのロード with open('.model/E_model.pkl', 'rb') as f: E_model = cloudpickle.load(f) app = Flask(__name__) def Anomaly_score(x, fake_img, z_out_real, D, Lambda=0.1): # テスト画像xと生成画像fake_imgのピクセルレベルの差の絶対値を求めて、ミニバッチごとに和を求める residual_loss = torch.abs(x-fake_img) residual_loss = residual_loss.view(residual_loss.size()[0], -1) residual_loss = torch.sum(residual_loss, dim=1) # テスト画像xと生成画像fake_imgを識別器Dに入力し、特徴量マップを取り出す _, x_feature = D(x, z_out_real) _, G_feature = D(fake_img, z_out_real) # テスト画像xと生成画像fake_imgの特徴量の差の絶対値を求めて、ミニバッチごとに和を求める discrimination_loss = torch.abs(x_feature-G_feature) discrimination_loss = discrimination_loss.view( discrimination_loss.size()[0], -1) discrimination_loss = torch.sum(discrimination_loss, dim=1) # ミニバッチごとに2種類の損失を足し算する loss_each = (1-Lambda)*residual_loss + Lambda*discrimination_loss # ミニバッチ全部の損失を求める total_loss = torch.sum(loss_each) return total_loss, loss_each, residual_loss @app.route("/", methods=["GET", "POST"]) def upload_file(): if request.method == "GET": return render_template("index.html") if request.method == "POST": # アプロードされたファイルをいったん保存する f = request.files["file"] filepath = "./static/" + datetime.now().strftime("%Y%m%d%H%M%S") + ".jpg" f.save(filepath) # 保存したファイル+dummyファイルのリスト作成 predic_img_list=list() predic_img_list.append(filepath) dummy_filepath="./static/" + "dummy" + ".jpg" predic_img_list.append(dummy_filepath) # Datasetを作成 mean = (0.5,) std = (0.5,) test_dataset = GAN_Img_Dataset( file_list=predic_img_list, transform=ImageTransform(mean, std)) # DataLoaderを作成 モデルの関係上,バッチサイズが2必要なので・・・ batch_size = 2 test_dataloader = torch.utils.data.DataLoader( test_dataset, batch_size=batch_size, shuffle=False) # 動作の確認 batch_iterator = iter(test_dataloader) # イテレータに変換 imges = next(batch_iterator) # 1番目の要素を取り出す x=imges[0:2] x = x.to(device) # 教師データの画像をエンコードしてzにしてから、Gで生成 z_out_real = E_model(imges.to(device)) imges_reconstract = G_model(z_out_real) # 異常度スコアを求める loss, loss_each, residual_loss_each = Anomaly_score( x, imges_reconstract, z_out_real, D_model, Lambda=0.1) # 損失の計算。トータルの損失 loss_each = loss_each.cpu().detach().numpy() # 1つ目のスコア(アップロードした画像の方のスコア)をresultへ result=np.round(loss_each[0], 0) if(result<175): strResultMsg=7 else: strResultMsg=8 return render_template("index.html", filepath=filepath, result=strResultMsg) if __name__ == "__main__": app.run(debug=True)
学習時に使用したモデルのコード上,Batch Normalizationを行っており, バッチのサイズが2以上ないといけないらしく, 今回推論時にダミーで1つファイルを追加している. ※要改善点
このコードと参考ブログのhtmlファイルで のようにアップロードした画像に対して7なのか8なのかを判別することが可能に. 実際は異常検知なので7以外は全て異常として運用しないといけない
References
つくりながら学ぶ! PyTorchによる発展ディープラーニング
Python × Flask × PyTorch 数字認識Webアプリのお手軽構築 - Qiita
Houssam Zenati, Chuan Sheng Foo, Bruno Lecouat, Gaurav Manek, Vijay Ramaseshan Chandrasekhar, "Efficient GAN", https://arxiv.org/abs/1802.06222 (2017).