scikit-learnで標準化,正規化
標準化の式
正規化の式
scikit-learn で
sklearn の StandardScaler と MinMaxScaler がそれぞれ 標準化 と 正規化 のモジュールです。主に使うメソッドは次の 3 つです。
fit
パラメータ(平均や標準偏差 etc)計算
transform
パラメータをもとにデータ変換
fit_transform
パラメータ計算とデータ変換をまとめて実行
赤:original data
青:標準化データ(standard)
緑:正規化データ(normalized)
code
# coding:utf-8 import matplotlib.pyplot as plt import numpy as np from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import MinMaxScaler # 元データ np.random.seed(seed=1) data = np.random.multivariate_normal( [7, 5], [[5, 0],[0, 2]], 100 ) # 標準化 sc = StandardScaler() data_std = sc.fit_transform(data) # 正規化 ms = MinMaxScaler() data_norm = ms.fit_transform(data) # プロット min_x = min(data[:,0]) max_x = max(data[:,0]) min_y = min(data[:,1]) max_y = max(data[:,1]) plt.figure(figsize=(5, 6)) plt.subplot(2,1,1) plt.title('StandardScaler') plt.xlim([-4, 10]) plt.ylim([-4, 10]) plt.scatter(data[:, 0], data[:, 1], c='red', marker='x', s=30, label='origin') plt.scatter(data_std[:, 0], data_std[:, 1], c='blue', marker='x', s=30, label='standard ') plt.legend(loc='upper left') plt.hlines(0,xmin=-4, xmax=10, colors='#888888', linestyles='dotted') plt.vlines(0,ymin=-4, ymax=10, colors='#888888', linestyles='dotted') plt.subplot(2,1,2) plt.title('MinMaxScaler') plt.xlim([-4, 10]) plt.ylim([-4, 10]) plt.scatter(data[:, 0], data[:, 1], c='red', marker='x', s=30, label='origin') plt.scatter(data_norm[:, 0], data_norm[:, 1], c='green', marker='x', s=30, label='normalize') plt.legend(loc='upper left') plt.hlines(0,xmin=-4, xmax=10, colors='#888888', linestyles='dotted') plt.vlines(0,ymin=-4, ymax=10, colors='#888888', linestyles='dotted') plt.show()