激安ラジコン(RC)の自動運転化計画※RCをEV3に変更しました
目的: ラジコンを自動運転させること
使ったもの
ハード
システムの概要
今回用いたコースはこちら
今回は使用言語をPythonに限定した.
行動の決定を,線の数を考慮したクラス分類問題として扱った.
画像転送部分(動画の配信)
- MJPG-streamerを使って,webカメラから取得した画像のストリーミングを行う.
*設定
fps:5
width:640
height:480
※下記のコマンド例では -f 10 -r 320x240(10fps,320×240)を指定しており,上記の設定値(5fps,640×480)とは異なる点に注意.
*コマンド
./mjpg_streamer -i "./input_uvc.so -f 10 -r 320x240 -d /dev/video0 -y -n" -o "./output_http.so -w ./www -p 8080"
モータ制御部分
- EV3の2つのモータの制御を行う.サーバ(PC)側の分類結果に対応する制御信号をMQTTで受信し,モータを駆動させる. *プログラム
学習・検証部分
- KerasでCNN部分を実装する.ImageNetで学習済みのVGG16をベースに,全結合層を付け替えてファインチューニングを行う. *プログラム
# coding: utf-8
"""Fine-tune an ImageNet-pretrained VGG16 to classify driving images.

Training images are read from ``data/train`` and validation images from
``data/test``; each directory must contain one sub-directory per entry of
``classes``. The fine-tuned weights are written to ``results/20180802.h5``
and the per-epoch metrics to ``results/20180802.txt``.
"""
import os

import numpy as np
from keras import optimizers
from keras.applications.vgg16 import VGG16
from keras.layers import Input, Activation, Dropout, Flatten, Dense
from keras.models import Sequential, Model
from keras.preprocessing.image import ImageDataGenerator

# Class labels. The list order fixes the index of each softmax output and must
# match the `classes` list of the inference script. The names are also the
# sub-directory names read by flow_from_directory, so the spelling ("foward")
# has to match the directories on disk.
classes = ['foward_1', 'foward_2', 'left_1', 'left_2', 'right_1', 'right_2', 'other']

batch_size = 32
nb_classes = len(classes)

img_rows, img_cols = 150, 150
channels = 3

train_data_dir = 'data/train'
validation_data_dir = 'data/test'

# Dataset sizes are hard-coded; update them whenever images are added/removed.
nb_train_samples = 1699
nb_val_samples = 447
nb_epoch = 100

result_dir = 'results'
if not os.path.exists(result_dir):
    os.mkdir(result_dir)


def save_history(history, result_file):
    """Write the per-epoch metrics recorded in a Keras ``History`` to a file.

    The file is tab-separated: a header ``epoch<TAB><metric>...`` followed by
    one row per epoch. Metric names are taken from ``history.history`` and
    sorted for a stable column order, so this works whichever metric names the
    installed Keras version records (e.g. ``acc`` or ``accuracy``).

    Args:
        history: the object returned by ``model.fit_generator``.
        result_file: path of the text file to create (overwritten if present).
    """
    logs = history.history
    names = sorted(logs)
    with open(result_file, 'w') as fp:
        fp.write('\t'.join(['epoch'] + names) + '\n')
        for epoch, values in enumerate(zip(*[logs[name] for name in names])):
            fp.write('\t'.join([str(epoch)] + [str(v) for v in values]) + '\n')


if __name__ == '__main__':
    # VGG16 convolutional base with ImageNet weights. The original fully
    # connected classifier is dropped (include_top=False) and replaced below.
    input_tensor = Input(shape=(img_rows, img_cols, channels))
    vgg16 = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)

    # New classification head. Flatten's input_shape excludes the batch axis.
    top_model = Sequential()
    top_model.add(Flatten(input_shape=vgg16.output_shape[1:]))
    top_model.add(Dense(256, activation='relu'))
    top_model.add(Dropout(0.5))
    top_model.add(Dense(nb_classes, activation='softmax'))

    # Chain the VGG16 base and the head into a single model.
    # NOTE(review): input=/output= are the legacy Keras 1 keyword names; newer
    # Keras versions require inputs=/outputs=.
    model = Model(input=vgg16.input, output=top_model(vgg16.output))

    # Freeze the first 15 layers (input layer + VGG16 blocks 1-4 in the
    # standard Keras layout), so only block5 and the new head are trained.
    for layer in model.layers[:15]:
        layer.trainable = False

    # SGD with a small learning rate and momentum, so the pretrained weights
    # change only gradually during fine-tuning.
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
                  metrics=['accuracy'])

    # Light augmentation of the training images only.
    # - No rescale / preprocess_input: the model is trained on raw 0-255
    #   pixels, and the inference script feeds raw pixels as well, so both
    #   sides must be changed together.
    # - No horizontal/vertical flips: mirroring would presumably turn a left_*
    #   scene into a right_* one while keeping the original label.
    # NOTE(review): rotation_range is in degrees and channel_shift_range is in
    # pixel-value units, so 0.1 makes both of them practically no-ops.
    train_datagen = ImageDataGenerator(featurewise_center=False,
                                       samplewise_center=False,
                                       featurewise_std_normalization=False,
                                       samplewise_std_normalization=False,
                                       zca_whitening=False,
                                       rotation_range=0.1,
                                       width_shift_range=0.2,
                                       height_shift_range=0.2,
                                       shear_range=0.2,
                                       zoom_range=0.2,
                                       channel_shift_range=0.1,
                                       fill_mode='nearest',
                                       rescale=None)
    test_datagen = ImageDataGenerator()

    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_rows, img_cols),
        color_mode='rgb',
        classes=classes,
        class_mode='categorical',
        batch_size=batch_size,
        shuffle=True)

    validation_generator = test_datagen.flow_from_directory(
        validation_data_dir,
        target_size=(img_rows, img_cols),
        color_mode='rgb',
        classes=classes,
        class_mode='categorical',
        batch_size=batch_size,
        shuffle=True)

    # Fine-tuning.
    # NOTE(review): samples_per_epoch / nb_epoch / nb_val_samples are Keras 1
    # argument names; Keras 2 only accepts them through its legacy shim.
    history = model.fit_generator(
        train_generator,
        samples_per_epoch=nb_train_samples,
        nb_epoch=nb_epoch,
        validation_data=validation_generator,
        nb_val_samples=nb_val_samples)

    model.save_weights(os.path.join(result_dir, '20180802.h5'))
    save_history(history, os.path.join(result_dir, '20180802.txt'))
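- 使用したCNNモデル: ImageNet学習済みVGG16の畳み込み部分(全結合層を除く)に,全結合層(256ユニット,ReLU)→Dropout(0.5)→全結合層(7クラス,softmax)を接続したもの.入力は150×150のRGB画像.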
走行テスト部分
*ラズベリーパイ3
webカメラから動画のストリーミングを行う.
*サーバ
ストリーミングされた動画をキャプチャし,その画像を学習済みモデルに入力して分類を行う.分類結果に対応するモータの制御信号を,MQTTを用いてEV3へ送信する.
#!/usr/bin/env python
# coding: utf-8
"""Drive the EV3 from the PC by classifying the Raspberry Pi camera stream.

Frames are read from the MJPG-streamer stream at ``URL``. Every
``predict_every``-th frame is resized to 150x150 and classified by the
fine-tuned VGG16 model. The motor speeds for the predicted class are then
published over MQTT on ``topic`` as the string "<first>,<second>".
"""
import os
import socket
import sys
import threading
import time
from io import BytesIO

import cv2
import matplotlib.pyplot as plt
import numpy as np
import paho.mqtt.client as mqtt
from IPython import display
from PIL import Image
from PIL import ImageOps
from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.layers import Input, Activation, Dropout, Flatten, Dense
from keras.models import Sequential, Model
from keras.preprocessing import image

result_dir = 'results'
# Must be in the same order as the class list used for training.
classes = ['foward_1', 'foward_2', 'left_1', 'left_2', 'right_1', 'right_2', 'other']
nb_classes = len(classes)

img_height, img_width = 150, 150
channels = 3

# Motor speeds published for each predicted class, sent as "<first>,<second>".
motor_commands = {
    'foward_1': (310, 300),
    'foward_2': (410, 400),
    'left_1': (500, 200),
    'left_2': (300, 200),
    'right_1': (200, 500),
    'right_2': (200, 300),
    'other': (-200, 300),
}

# Only every n-th frame is classified so that prediction keeps up with the stream.
predict_every = 5

# Number of most probable classes printed after each prediction.
top = 3

URL = "http://192.168.0.4:8080/?action=stream"

host = '192.168.0.17'
port = 1883
keepalive = 60
topic = 'topic/motor/dt'

# Rebuild the training architecture (VGG16 base + FC head) so the saved weights
# can be loaded into it. weights=None avoids downloading the ImageNet weights,
# since load_weights() below overwrites every layer anyway.
input_tensor = Input(shape=(img_height, img_width, channels))
vgg16 = VGG16(include_top=False, weights=None, input_tensor=input_tensor)

fc = Sequential()
fc.add(Flatten(input_shape=vgg16.output_shape[1:]))
fc.add(Dense(256, activation='relu'))
fc.add(Dropout(0.5))
fc.add(Dense(nb_classes, activation='softmax'))

model = Model(input=vgg16.input, output=fc(vgg16.output))
model.load_weights(os.path.join(result_dir, '20180803_3.h5'))
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# Single matplotlib image artist reused for the live preview (created lazily).
_preview = None


def _show_preview(img):
    """Show ``img`` in the preview window, reusing one image artist.

    Calling plt.imshow() for every frame would stack a new image on the axes
    each time; after the first frame only the pixel data is replaced.
    """
    global _preview
    if _preview is None:
        _preview = plt.imshow(img)
    else:
        _preview.set_data(img)
    plt.pause(0.01)  # redraw the figure and let the GUI event loop run


def cap(frame, client):
    """Classify one camera frame and publish the matching motor command.

    Args:
        frame: BGR image as returned by ``cv2.VideoCapture.read()``.
        client: connected paho MQTT client; the command is published on
            the module-level ``topic``.
    """
    # OpenCV delivers BGR; PIL, matplotlib and the model work in RGB.
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(np.uint8(rgb)).resize((img_width, img_height))
    _show_preview(img)

    # Raw 0-255 pixels without preprocess_input, matching how the model was
    # trained (the training generators use rescale=None).
    batch = np.expand_dims(image.img_to_array(img), axis=0)
    pred = model.predict(batch)[0]

    label = classes[int(np.argmax(pred))]
    print(label)
    first, second = motor_commands[label]
    client.publish(topic, str(first) + "," + str(second))

    # Print the `top` most probable classes with their probabilities.
    top_indices = pred.argsort()[-top:][::-1]
    result = [(classes[i], pred[i]) for i in top_indices]
    print(result)
    print("")


def main():
    """Connect to the broker and the stream, then classify until interrupted."""
    client = mqtt.Client()
    client.connect(host, port, keepalive)
    # paho only sends keepalive pings and processes broker traffic while a
    # network loop runs; run it in a background thread.
    client.loop_start()

    vc = cv2.VideoCapture(URL)
    if not vc.isOpened():
        print("Could not open video stream: " + URL)

    frame_count = 0
    try:
        while vc.isOpened():
            is_capturing, frame = vc.read()
            if not is_capturing:
                # Stream ended or no frame could be decoded; frame is unusable.
                break
            frame_count += 1
            if frame_count % predict_every == 0:
                cap(frame, client)
    except KeyboardInterrupt:
        pass  # Ctrl-C (or the notebook stop button) ends the loop normally.
    finally:
        vc.release()
        client.disconnect()
        client.loop_stop()


if __name__ == '__main__':
    main()
*EV3
受信した制御信号をモータに反映し,モータを駆動させる.
#!/usr/bin/env python3
"""Motor controller running on the EV3 (ev3dev).

Subscribes to the MQTT topic "topic/motor/dt". For every "<a>,<d>" message it
runs the motor on outA at speed <a> and the motor on outD at speed <d> for
250 ms, then brakes.
"""
import paho.mqtt.client as mqtt
from ev3dev.auto import *

# Number of valid motor commands received so far (printed for debugging).
count = 0

ma = Motor('outA')
md = Motor('outD')


def on_connect(client, userdata, flags, rc):
    """Subscribe to the command topic once the broker accepts the connection.

    Subscribing here rather than once at start-up renews the subscription
    whenever paho reconnects.
    """
    print("Connected with result code " + str(rc))
    client.subscribe("topic/motor/dt")


def on_message(client, userdata, msg):
    """Apply one "<a>,<d>" speed command to the two motors.

    Malformed payloads (wrong encoding, wrong number of fields, non-integer
    values) are reported and skipped. Otherwise the exception would be raised
    inside the MQTT callback, where paho can propagate it out of
    loop_forever().
    """
    global count
    try:
        msg_str = msg.payload.decode("utf-8")
        msg_array = msg_str.split(",")
        speed_a, speed_d = [int(v) for v in msg_array]
    except ValueError:  # also covers UnicodeDecodeError
        print("Ignoring malformed message: " + repr(msg.payload))
        return

    count += 1
    print(count)
    print(msg_array)

    # Each command only runs the motors for 250 ms and then brakes, so the
    # robot stops by itself when the server stops sending commands.
    ma.run_timed(time_sp=250, speed_sp=speed_a, stop_action='brake')
    md.run_timed(time_sp=250, speed_sp=speed_d, stop_action='brake')


client = mqtt.Client()
# Register callbacks before connecting so that no event can be missed.
client.on_connect = on_connect
client.on_message = on_message
client.connect("192.168.0.17", 1883, 60)

# Put both motors in direct (duty-cycle) mode at 0 until the first command.
ma.run_direct()
md.run_direct()
ma.duty_cycle_sp = 0
md.duty_cycle_sp = 0

# Process MQTT traffic (and the callbacks above) until the process is stopped.
client.loop_forever()