激安ラジコン(RC)の自動運転化計画※RCをEV3に変更しました

目的:ラジコンを自動運転させること

使ったもの

ハード

  • ラジコン:レゴ® マインドストーム® EV3
  • ラズベリーパイ3
  • カメラ:LOGICOOL C270

ソフト

  • 言語:python
  • DLライブラリ:Keras(on Tensorflow)
  • その他:Opencv,numpy,paho-mqtt...

システムの概要


    今回用いたコースはこちら

    今回は使用言語をPythonに限定した.

    RCの行動の決定を,線の数を考慮したクラス分類問題として扱った.

    [f:id:kobakenkken:20180624180334p:plain]

     

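  • RCの行動の種類
    • 前(forward),右(right),少し右(little right),左(left),少し左(little left),その他(other)の6種類(※後述の学習・走行プログラムでは,前進も2段階に分けた foward_1, foward_2, left_1, left_2, right_1, right_2, other の7クラスを用いている)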

    画像転送部分(動画の配信)

    • MJPG-streamerを使って,webカメラから取得した画像をストリーミング配信する.

    *設定(※下記のコマンド例では -f 10 -r 320x240 を指定しており,以下の値とは一致していない点に注意)

    fps:5

    width:640

    height:480

    *コマンド

    ./mjpg_streamer -i "./input_uvc.so -f 10 -r 320x240 -d /dev/video0 -y -n" -o "./output_http.so -w ./www -p 8080"
    

    モータ制御部分

    • EV3の二つのモータを制御する.サーバ(PC)側の分類結果に対応する制御信号をMQTTで受信し,モータを駆動させる. *プログラム

    学習・検証部分

    • KerasによりCNN部分の実装を行う. *プログラム
    # coding: utf-8
    """Fine-tune an ImageNet-pretrained VGG16 to classify course images into driving actions.

    Training/validation images are read with ``flow_from_directory`` from
    ``data/train`` and ``data/test`` (one sub-directory per entry in ``classes``).
    The fine-tuned weights and the per-epoch training history are written to
    ``results/``.
    """
    import os

    from keras import optimizers
    from keras.applications.vgg16 import VGG16
    from keras.layers import Dense, Dropout, Flatten, Input
    from keras.models import Model, Sequential
    from keras.preprocessing.image import ImageDataGenerator

    # These names are also the sub-directory names under the data directories,
    # so the "foward" spelling must match the folders on disk.
    classes = ['foward_1', 'foward_2', 'left_1', 'left_2', 'right_1', 'right_2', 'other']

    batch_size = 32
    nb_classes = len(classes)

    img_rows, img_cols = 150, 150
    channels = 3

    train_data_dir = 'data/train'
    validation_data_dir = 'data/test'
    nb_train_samples = 1699
    nb_val_samples = 447
    nb_epoch = 100

    result_dir = 'results'


    def build_model():
        """Return VGG16 (ImageNet weights, no top) plus a dense softmax head, compiled for fine-tuning."""
        # include_top=False drops VGG16's own fully-connected layers; we attach our own head.
        input_tensor = Input(shape=(img_rows, img_cols, channels))
        vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)

        # Classifier head. Flatten's input_shape excludes the batch dimension.
        top_model = Sequential()
        top_model.add(Flatten(input_shape=vgg16.output_shape[1:]))
        top_model.add(Dense(256, activation='relu'))
        top_model.add(Dropout(0.5))
        top_model.add(Dense(nb_classes, activation='softmax'))

        # `inputs`/`outputs` are the Keras 2 keyword names; the Keras 1 spellings
        # `input`/`output` are rejected by newer releases.
        model = Model(inputs=vgg16.input, outputs=top_model(vgg16.output))

        # Freeze the first 15 layers so only the later layers and the head are
        # updated. With the Input layer at index 0 this is intended to leave
        # VGG16's last conv block trainable -- verify with model.summary().
        for layer in model.layers[:15]:
            layer.trainable = False

        # Original author's note: SGD may suit fine-tuning better than adaptive
        # optimizers (not verified here).
        model.compile(loss='categorical_crossentropy',
                      optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
                      metrics=['accuracy'])
        return model


    def make_generators():
        """Return (train, validation) directory iterators; only the training data is augmented.

        No rescaling or ``preprocess_input`` is applied, so the network is trained
        on raw 0-255 pixel values and inference must feed the same unscaled input.
        """
        train_datagen = ImageDataGenerator(
            rotation_range=0.1,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            channel_shift_range=0.1,
            fill_mode='nearest')
        test_datagen = ImageDataGenerator()

        flow_args = dict(target_size=(img_rows, img_cols),
                         color_mode='rgb',
                         classes=classes,
                         class_mode='categorical',
                         batch_size=batch_size,
                         shuffle=True)
        train_generator = train_datagen.flow_from_directory(train_data_dir, **flow_args)
        validation_generator = test_datagen.flow_from_directory(validation_data_dir, **flow_args)
        return train_generator, validation_generator


    def save_history(history, result_file):
        """Write the per-epoch metrics of a Keras ``History`` object to a tab-separated file.

        The header row is ``epoch`` followed by the recorded metric names (sorted);
        each following row holds one epoch's values.
        """
        metrics = sorted(history.history)
        columns = [history.history[name] for name in metrics]
        with open(result_file, 'w') as fp:
            fp.write('\t'.join(['epoch'] + metrics) + '\n')
            for epoch, values in enumerate(zip(*columns)):
                fp.write('\t'.join([str(epoch)] + [str(v) for v in values]) + '\n')


    def main():
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)

        model = build_model()
        train_generator, validation_generator = make_generators()

        # Keras 2 counts *batches* per epoch. The Keras 1 arguments
        # `samples_per_epoch`/`nb_val_samples` would be forwarded unchanged as
        # batch counts, running batch_size times more batches than intended, so
        # convert sample counts to batch counts (rounded up so every sample is
        # seen once per epoch).
        history = model.fit_generator(
            train_generator,
            steps_per_epoch=(nb_train_samples + batch_size - 1) // batch_size,
            epochs=nb_epoch,
            validation_data=validation_generator,
            validation_steps=(nb_val_samples + batch_size - 1) // batch_size)

        model.save_weights(os.path.join(result_dir, '20180802.h5'))
        save_history(history, os.path.join(result_dir, '20180802.txt'))


    if __name__ == '__main__':
        main()
    
    • 使用したCNNモデル:

     

    走行テスト部分

    *ラズベリーパイ3

    webカメラから動画のストリーミングを行う.

     
    

    *サーバ

    ストリーミングされた動画をキャプチャし,その画像を学習済みモデルへ入力して分類を行う.その分類結果に対応するモータの制御信号をMQTTで送信する.

     
    
    #!/usr/bin/env python
    # coding: utf-8
    """Classify frames from the Raspberry Pi MJPG stream and publish motor commands over MQTT.

    Every ``PREDICT_EVERY``-th frame is converted to RGB, resized to 150x150 and
    fed to the fine-tuned VGG16 model. The winning class is mapped to a pair of
    motor speeds, which is published on ``topic`` as the string ``"<speed1>,<speed2>"``.
    """
    import os

    import cv2
    import matplotlib.pyplot as plt
    import numpy as np
    import paho.mqtt.client as mqtt
    from keras.applications.vgg16 import VGG16
    from keras.layers import Dense, Dropout, Flatten, Input
    from keras.models import Model, Sequential
    from keras.preprocessing import image
    from PIL import Image

    result_dir = 'results'
    classes = ['foward_1', 'foward_2', 'left_1', 'left_2', 'right_1', 'right_2', 'other']
    nb_classes = len(classes)
    img_height, img_width = 150, 150
    channels = 3

    URL = "http://192.168.0.4:8080/?action=stream"

    host = '192.168.0.17'
    port = 1883
    keepalive = 60
    topic = 'topic/motor/dt'

    # Classify only every N-th frame to keep up with the stream.
    PREDICT_EVERY = 5

    # Motor speed pair published for each predicted class.
    MOTOR_COMMANDS = {
        'foward_1': (310, 300),
        'foward_2': (410, 400),
        'left_1': (500, 200),
        'left_2': (300, 200),
        'right_1': (200, 500),
        'right_2': (200, 300),
        'other': (-200, 300),
    }


    def build_model():
        """Rebuild the VGG16 + dense-head network and load the fine-tuned weights."""
        input_tensor = Input(shape=(img_height, img_width, channels))
        vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)

        fc = Sequential()
        fc.add(Flatten(input_shape=vgg16.output_shape[1:]))
        fc.add(Dense(256, activation='relu'))
        fc.add(Dropout(0.5))
        fc.add(Dense(nb_classes, activation='softmax'))

        # `inputs`/`outputs` are the Keras 2 keyword names.
        net = Model(inputs=vgg16.input, outputs=fc(vgg16.output))
        net.load_weights(os.path.join(result_dir, '20180803_3.h5'))
        net.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
        return net


    model = build_model()

    # matplotlib image artist reused across frames (created on first use).
    _preview = None


    def cap(frame, client):
        """Classify one frame, show it, and publish the matching motor command.

        Args:
            frame: image array as returned by ``cv2.VideoCapture.read()`` (BGR order).
            client: connected ``paho.mqtt.client.Client`` used to publish on ``topic``.
        """
        global _preview

        # OpenCV delivers BGR; the model and matplotlib work in RGB.
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(np.uint8(rgb)).resize((img_width, img_height))

        # Update a single image artist instead of stacking a new one every frame.
        if _preview is None:
            _preview = plt.imshow(np.asarray(img))
        else:
            _preview.set_data(np.asarray(img))
        plt.pause(0.01)

        # Raw 0-255 pixels, without preprocess_input.
        x = image.img_to_array(img)
        pred = model.predict(np.expand_dims(x, axis=0))[0]

        label = classes[int(np.argmax(pred))]
        print(label)
        client.publish(topic, "{},{}".format(*MOTOR_COMMANDS[label]))

        # Print the three most probable classes for debugging.
        top_indices = pred.argsort()[-3:][::-1]
        print([(classes[i], pred[i]) for i in top_indices])
        print("")


    def main():
        client = mqtt.Client()
        client.connect(host, port, keepalive)
        # Run the network loop in a background thread so keepalive pings and
        # broker acknowledgements are handled while inference runs here.
        client.loop_start()

        vc = cv2.VideoCapture(URL)
        frame_count = 0
        try:
            while vc.isOpened():
                ok, frame = vc.read()
                if not ok:
                    # Stream ended or a frame failed to decode; frame is None here.
                    break
                frame_count += 1
                if frame_count % PREDICT_EVERY == 0:
                    cap(frame, client)
        except KeyboardInterrupt:
            pass  # Ctrl-C is the normal way to stop the script.
        finally:
            vc.release()
            client.loop_stop()
            client.disconnect()


    if __name__ == '__main__':
        main()
    

    *EV3

    受信した制御信号をモータへ反映し,モータを駆動させる.

     
    
    #!/usr/bin/env python3
    """Drive the EV3 motors on ports A and D from MQTT messages of the form "<speed A>,<speed D>".

    Each valid message runs both motors for ``RUN_TIME_MS`` milliseconds at the
    requested speeds and then brakes, so the robot stops if commands stop arriving.
    """
    import paho.mqtt.client as mqtt
    from ev3dev.auto import *  # provides Motor

    BROKER_HOST = "192.168.0.17"
    BROKER_PORT = 1883
    KEEPALIVE = 60
    TOPIC = "topic/motor/dt"
    RUN_TIME_MS = 250

    ma = Motor('outA')
    md = Motor('outD')

    # Number of commands applied so far (printed for debugging).
    received = 0


    def parse_speeds(payload):
        """Parse a UTF-8 payload b"<a>,<d>" into two ints.

        Raises:
            ValueError: if the payload is not valid UTF-8, does not contain exactly
                two comma-separated fields, or a field is not an integer.
        """
        parts = payload.decode("utf-8").split(",")
        if len(parts) != 2:
            raise ValueError("expected '<speed A>,<speed D>', got %r" % (payload,))
        return int(parts[0]), int(parts[1])


    def on_connect(client, userdata, flags, rc):
        print("Connected with result code " + str(rc))
        client.subscribe(TOPIC)


    def on_message(client, userdata, msg):
        global received
        try:
            speed_a, speed_d = parse_speeds(msg.payload)
        except ValueError as err:  # also covers UnicodeDecodeError
            # Skip malformed commands rather than letting the exception escape
            # the callback and stop the network loop.
            print("ignored message:", err)
            return

        received += 1
        print(received)
        print([speed_a, speed_d])
        ma.run_timed(time_sp=RUN_TIME_MS, speed_sp=speed_a, stop_action='brake')
        md.run_timed(time_sp=RUN_TIME_MS, speed_sp=speed_d, stop_action='brake')


    def main():
        # Zero the duty cycle before entering run-direct mode so the motors start idle.
        ma.duty_cycle_sp = 0
        md.duty_cycle_sp = 0
        ma.run_direct()
        md.run_direct()

        client = mqtt.Client()
        # Register callbacks before connecting so no event is missed.
        client.on_connect = on_connect
        client.on_message = on_message
        client.connect(BROKER_HOST, BROKER_PORT, KEEPALIVE)
        client.loop_forever()


    if __name__ == '__main__':
        main()
    

    実際の走行動画

    結果・考察