以前のブログ記事でStreamlitを使って簡単な機械学習WEBアプリを作るというのをやりましたが、今回はその兄弟企画でPlotly社が提供しているDashというライブラリを使って機械学習WEBアプリを作ってみたいと思います。
Plotly社はPlotlyやPlotly Expressというとても便利な可視化ライブラリを提供しているのですが、そのうちの一つとして簡単にアプリを開発できるDashを提供しています。
DashはPythonベースで開発することができ、同じPythonでアプリ開発するときによく使われるDjangoやFlaskよりもずっと簡単にアプリが作れるので、Pythonで何かデモアプリを作りたいようなときに重宝されます。
そしてStreamlitよりは拡張性が高いので、もう少し凝ったものが作りたい時に向いているのではないかと思います。
というわけで、今回はDashを使った機械学習WEBアプリを作ってみます(デプロイはしません・・)。
Dashで機械学習WEBアプリを作ってみる
今回はSeabornに入っている"tip"チップ額のデータを使って、どのチップ額を予測するWEBアプリを作ってみたいと思います。
データのダウンロード
まずはデータをダウンロードしましょう。
seabornをimportして、load_datasetの中から"tips"を指定してダウンロードします。
1 2 3 4 5 6 7 |
import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns from sklearn.linear_model import LinearRegression df = sns.load_dataset('tips') |
線形重回帰によるモデリング
データがダウンロードできたので、そのデータを使ってチップ額を予測するモデルを作ります。
今回は簡単のため、説明変数を"size", "time", "total_bill"とし、目的変数を"tip"とします。
ちなみに"size"は食事をした人数、"time"はランチかディナーか、"total_bill"は食事の合計額で、"tip"がチップ額になります。
モデリングは簡単のため標準化とかも全くせず、とりあえず線形重回帰でモデリングをするだけにします。
1 2 3 4 5 6 7 8 9 10 |
#線形重回帰による数値予測モデリング use_data = df[['total_bill','size','time','tip']] use_data = pd.get_dummies(use_data, drop_first=True) X = use_data[['total_bill','size','time_Dinner']] Y = use_data[['tip']] clf = LinearRegression() clf.fit(X, Y) |
もちろん、このモデリング自体はこのコード内になく、別のところで作ったモデルを読み込むようにすればよいのですが、一つのファイルにまとまっていた方がわかりやすいかと思い、今回はこのようにしています。
Plotlyによるグラフ作成(箱ひげ図・散布図)
せっかくWEBアプリを作るので、ただ数値を入れて予測するだけでは味気ないですから、アプリで図表も見れるようにしておきたいと思います。
今回は説明変数として3つの変数を使うので、それらの分布をアプリで確認できるようにします。
なので、それ用のグラフをPlotlyで作ります。
make_subplotsで1行3列の箱を用意し、その中にadd_traceして3つのグラフを入れていきます。
入れるのは、「ランチorディナー」と「チップ額」の箱ひげ図、「合計額」と「チップ額」の散布図、そして「人数」と「チップ額」の散布図です。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
#表示するグラフの作成 import plotly.graph_objects as go from plotly.subplots import make_subplots tip_plots = make_subplots(rows=1, cols=3, start_cell='bottom-left') tip_plots.add_trace(go.Box(x=df['time'], y=df['tip'], name='time vs tip'), row=1, col=1) tip_plots.add_trace(go.Scatter(x=df['total_bill'], y=df['tip'], mode='markers', name='total_bill vs tip'), row=1, col=2) tip_plots.add_trace(go.Scatter(x=df['size'], y=df['tip'], mode='markers', name='size vs tip'), row=1, col=3) tip_plots.update_layout( xaxis_title_text='Time (Lunch or Dinner)', yaxis_title_text='Tip ($)', ) tip_plots.update_layout( xaxis2_title_text='Total bill ($)', yaxis2_title_text='Tip ($)', ) tip_plots.update_layout( xaxis3_title_text='Size (人)', yaxis3_title_text='Tip ($)', ) |
まだこの段階ではアプリはできていないので、上記で作ったのはあとでアプリに入れるためのグラフということになります。
アプリの作成(app.layout)
では本命のアプリ部分を作っていきます。
まずはアプリの静的な部分(レイアウト)を作っていき、そのあとにアクションを加えていくというふうに進めていきます。
アプリのレイアウトは、app.layoutという中にHTMLのような形で書いていき、ボタンや数値の入力枠などはdash_core_components (dcc)を使って付け加えていきます。
今回は、数値である「合計額("total bill")」と「人数("size")」を入力し、2値である「ランチorディナー("time")」をラジオボタンを入力するようにします。
そして、そのインプットを読み込んで先ほど作ったモデルで予測を行い、一番最後にチップの予測額を出力します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
###アプリ部分 import dash import dash_html_components as html import dash_core_components as dcc import dash_table from dash.dependencies import Input, Output, State app = dash.Dash() app.layout = html.Div([ html.H1('チップの額を予測するアプリです!', style={"textAlign":"center"}), html.H2('まずはグラフを見てみましょう!'), dcc.Graph( id='graph', figure=tip_plots, style={} ), html.H2('予測用のデータをインプットしてみましょう!'), dcc.Input( id='total_bill', placeholder='total bill ここに値を入れてください', type='text', style={"width":"20%"}, value='' ), dcc.Input( id='size', placeholder='size ここに値を入れてください', type='text', style={"width":"20%"}, value='' ), dcc.RadioItems( id='time', options=[ {'label':'ランチ','value':'Lunch'}, {'label':'ディナー','value':'Dinner'} ], value='Lunch', labelStyle={'display':'inline-block'} ), html.Button( id='submit-button', n_clicks=0, children='Submit' ), html.H2('チップの予測額はいくらかな?'), html.Div( id='output-pred', style={"textAlign":"center","fontSize":30, "color":"red"} ) ]) |
これでアプリのレイアウトができたので、次に動きを付け加えていきます。
アプリの動き(callback)
さて、アプリの動きはcallbackというものを使って実装します。
今はまだ数値を入れる枠やボタンしかないので、これでアプリにしても何も起こりません。
なので必要なことはインプットの値を読み込んだら予測を行い、それを画面に出力するということです。
そちらを@app.callbackという中で指定します。
アクションはInputの"submit-button"が押されたら始まり、それによって次の関数predictionが動作し、そのreturn結果がOutputから返されるという流れになっています。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
@app.callback( Output('output-pred', 'children'), Input('submit-button', 'n_clicks'), [State('total_bill','value'), State('size','value'), State('time','value')] ) def prediction(n_clicks, total_bill, size, time): if time=='Lunch': dinner01 = 0 else: dinner01 = 1 if (total_bill and size): value_df = pd.DataFrame([], columns=['Total bill', 'Size', 'Dinner flag']) record = pd.Series([total_bill, size, dinner01], index=value_df.columns, dtype='float64') value_df = value_df.append(record, ignore_index=True) Y_pred = clf.predict(value_df) return_text = 'チップ額はおそらく'+str('{:.2g}'.format(Y_pred[0,0])+'ドルくらいでしょう!') return return_text else: return 'ちゃんとデータを入力してね!' if __name__ =='__main__': app.run_server() |
そして、最後にapp.run_server()をすれば、アプリが動くということになります。
このコードができたら、ターミナルで以下を実行すればOKです(コードをapp.pyとした場合)。
command
python app.py
するとターミナルにこのようなメッセージが出るので、出てきたURLをブラウザにコピーするとWEBアプリが表示されます。
全体はこのように上部にグラフが、そして下部にインプットとその結果が表示されています。
最初はなにも値が入っていませんので予測額もなく「ちゃんとデータを入力してね」になっています。
次に値を入れてみます。
値を入れましたが、まだSubmitボタンを押していないので、予測額の表示は変わっていないと思います。
そしてSubmitボタンを押すと、callbackが動いてチップの予測額が表示されました!
HTMLとかcallbackとかを頑張って書いていけばけっこう凝ったアプリができますので、ぜひやってみてください!