on
딥러닝 - SAM + CLIP을 접목한 엽면적 측정기 개발
PlantCV를 이용한 엽면적 측정기 개발에 이어 SAM + CLIP 엽면적 측정기 개발에 대한 포스트를 이어서 진행하고자 한다. 이전 포스트인 PlantCV 엽면적 측정기를 먼저 읽고 오는 것을 추천한다.
SAM (Segmentation Anything Model)
SAM (Segmentation Anything Model)은 이전 딥러닝 포스트에서도 여러번 다루었던 모델이다. SAM 모델에 대하여 간략하게 설명하자면 다음과 같다.
-
SAM은 입력한 텍스트 프롬프트나 클릭한 지점을 기반으로 AI가 이미지 내에서 특정 물체를 분리해주는 이미지 분할 모델이다. 이미지 내에서 특정 물체를 식별하고 분리할 수 있게 해준다.
-
“이미지 세그멘테이션 (segmentation)”을 위한 모델이며, 픽셀이 어떤 객체에 속해있는지를 식별하는 알고리즘이다.
-
Meta는 SAM 과 40억개의 파라미터 데이터셋 (SA-1B) 모델을 공개하였다.
-
SAM은 물체가 무엇인지에 대한 일반적인 개념을 학습 했고, 훈련중에 만나지 못한 물체 및 이미지 유형에 대해서도 이미지/비디오의 모든 객체에 대해서 마스크를 생성 가능 추가적인 훈련 없이도 새로운 이미지 도메인(물속 사진이나, 세포 현미경 사진등)에도 적용 가능하다.
SAM의 원리
SAM의 원리는 다음 참고 자료에 잘 설명 되어 있다.
Segment Anything 논문 정리
박진우 GitHub
SAM GitHub
GitHub: https://github.com/facebookresearch/segment-anything
Demo: https://segment-anything.com/demo
CLIP (Contrastive Language-Image Pre-training model)
-
CLIP 모델은 ViT(Vision Transformer)와 Transformer 언어 모델(Transformer-based language model)을 결합하여 이미지와 텍스트를 모두 처리할 수 있게 만들어놓은 모델이다.
-
여기서 ViT란 비지도학습을 통해 이미지에서 특징을 추출할 수 있도록 만들어진 CNN 모델이며, Transformer 언어 모델은 사전훈련(pre-trained)을 통해 텍스트 데이터를 학습해놓은 모델이다.
CLIP GitHub
GitHub: https://github.com/openai/CLIP
SAM with CLIP
알고리즘의 순서를 간단하게 표현하면 다음과 같다.
-
이미지가 SAM Image Encoder를 통과하여 이미지 내 모든 객체를 segmentation 한다.
-
CLIP의 text Encoder에 “Leaf”를 할당하고, CLIP이 segmentation한 객체들 중 Leaf에 해당하는 영역을 찾아낸다.
-
이미지에서 Leaf에 해당하는 영역의 마스크만 출력한다.
FastSAM
SAM의 경우 Python 3.8이상 Pytorch 1.7이상 Torchvision 0.8 이상인 환경에서만 구동이 가능하다. 하지만 Jetson Nano는 기본 환경이 Python 3.6이며, Jetson nano Jetpack 환경은 파이썬 3.8이상의 라이브러리를 지원하지 않기 때문에 직접 Pytorch, OpenCV를 빌드해서 사용해야 했다.
기존의 SAM 모델의 경우 GPU가 아닌 CPU로 세팅하여 실행할 경우 계산 시간이 매우 오래걸리는 단점이 존재하였다. 원래 계획은 본 SAM + CLIP 모델을 Jetson Nano에 임베딩할 생각이였으나, Jetson Nano의 메모리가 CLIP을 실행하기에는 부족한 한계가 있었다.
이에 따라 SAM + CLIP 모델은 따로 데스크탑에서 구축한 서버에서 구동하고, Jetson은 아두이노와의 통신 및 카메라로 촬영한 이미지를 서버에 전송하는 중간 역할로 변경하게 되었다.
딥러닝 구동 서버는 FastAPI
를 이용하여 구동하였다.
FastSAM은
기존의 SAM에서 사용하던 SA-1B 데이터 세트의 2%만을 사용하여 훈련된 CNN 모델이다.
FastSAM은
50배 빠른 런타임 속도로 SAM 방식과 비슷한 성능을 보였다.
FastSAM GitHub
GitHub: https://github.com/CASIA-IVA-Lab/FastSAM
소스코드
위 깃허브 링크로 부터 FastSAM을 다운로드하고, 실행파일인 Inference.py
를 이용하여 엽면적을 계산하도록 수정하였다.
-
text_prompt
를leaf
로 고정한 뒤, Jetson으로 부터 전송된 이미지가 업로드되는 폴더를img_path
로 지정한다. 모델은FastSAM-x.pt
로 고정한다. -
이때
text_prompt
가None
이 아니기 때문에process_image
함수의elif args.text_prompt != None: ann = prompt_process.text_prompt(text=args.text_prompt)
가 활성화 된다. -
여기서
ann
은 마스킹 된 부분의 픽셀을 의미하고, 이를 모두 합산하면 마스킹 된 부분의 총 픽셀수가 된다. 여기에 앞서 구한 실제 엽면적과 이미지 내 픽셀수 간의 비율을 곱하면 실제 엽면적을 산정할 수 있다.
import os
import argparse
from fastsam import FastSAM, FastSAMPrompt
import ast
import torch
import numpy as np
import pandas as pd
from PIL import Image
from utils.tools import convert_box_xywh_to_xyxy
def get_config():
args = argparse.Namespace()
args.model_path = "./weights/FastSAM-x.pt"
args.img_path = "./uploads/"
args.imgsz = 1024
args.iou = 0.9
args.text_prompt = "leaf"
args.conf = 0.4
args.output = "./output/"
args.randomcolor = True
args.point_prompt = "[[0,0]]"
args.point_label = "[0]"
args.box_prompt = "[[0,0,0,0]]"
args.better_quality = False
args.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
args.retina = True
args.withContours = False
return args
def process_image(img_path, args):
model = FastSAM(args.model_path)
args.point_prompt = ast.literal_eval(args.point_prompt)
args.box_prompt = convert_box_xywh_to_xyxy(ast.literal_eval(args.box_prompt))
args.point_label = ast.literal_eval(args.point_label)
input = Image.open(img_path)
input = input.convert("RGB")
everything_results = model(
input,
device=args.device,
retina_masks=args.retina,
imgsz=args.imgsz,
conf=args.conf,
iou=args.iou
)
bboxes = None
points = None
point_label = None
prompt_process = FastSAMPrompt(input, everything_results, device=args.device)
if args.box_prompt[0][2] != 0 and args.box_prompt[0][3] != 0:
ann = prompt_process.box_prompt(bboxes=args.box_prompt)
bboxes = args.box_prompt
elif args.text_prompt != None:
ann = prompt_process.text_prompt(text=args.text_prompt)
pixel_count = np.sum(ann)
leaf_area = pixel_count * 0.000508
leaf_area = round(leaf_area, 2)
print("실제 잎의 면적: ", leaf_area, 'cm^2')
elif args.point_prompt[0] != [0, 0]:
ann = prompt_process.point_prompt(
points=args.point_prompt, pointlabel=args.point_label
)
points = args.point_prompt
point_label = args.point_label
else:
ann = prompt_process.everything_prompt()
output_filename = os.path.splitext(os.path.basename(img_path))[0] + '.jpg'
output_path = os.path.join(args.output, output_filename)
prompt_process.plot(
annotations=ann,
output_path=output_path,
bboxes=bboxes,
points=points,
point_label=point_label,
withContours=args.withContours,
better_quality=args.better_quality,
)
return leaf_area
def main(args):
original_point_prompt = args.point_prompt
original_box_prompt = args.box_prompt
original_point_label = args.point_label
results = []
image_names = []
for img_file in sorted(os.listdir(args.img_path)):
if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
img_path = os.path.join(args.img_path, img_file)
args.point_prompt = original_point_prompt
args.box_prompt = original_box_prompt
args.point_label = original_point_label
results.append(process_image(img_path, args))
image_names.append(img_file)
df = pd.DataFrame({'Filename': image_names, 'Leaf_Area': results})
df.to_csv('results.csv', index=False)
print(results)
if __name__ == "__main__":
args = get_config()
main(args)
Arudino
- PlantCV와 마찬가지로 5번핀에 연결된 버튼을 누으면 진행, 4번핀에 연결된 버튼을 누르면 종료가 되는 구조는 동일하다. 다만, PlantCV를 사용했을 경우에는 Jetson Nano에서 이미지를 처리하고 결과를 내었다면, 이번엔 이미지를 서버에 보내서 처리해야하기 때문에 5번핀 버튼을 누르면 사진촬영, 4번핀 버튼을 누르면 종료 및 사진 전송이 된다.
#include <Wire.h>
#include <LiquidCrystal_I2C.h>
LiquidCrystal_I2C lcd(0x27, 16, 2);
const int debounceTime = 500; // 딜레이 시간을 밀리초 단위로 설정
void setup() {
lcd.init();
lcd.backlight();
Serial.begin(9600);
pinMode(5, INPUT_PULLUP); // 5번 핀 ==> Capture 버튼
pinMode(4, INPUT_PULLUP); // 4번 핀 ==> Send 버튼
}
void loop() {
int buttonState = digitalRead(5);
int exitButtonState = digitalRead(4);
if (buttonState == LOW) {
Serial.println("CAPTURE");
lcd.clear();
lcd.print("Capture complete");
delay(debounceTime); // 버튼 입력 후 일정 시간 동안 대기
}
if (exitButtonState == LOW) {
Serial.println("SEND");
lcd.clear();
lcd.print("Send Complete");
delay(debounceTime); // 버튼 입력 후 일정 시간 동안 대기
}
}
FastAPI
FastAPI를 사용하여 딥러닝 API 서버를 구축하였다. server.py
의 라우팅은 다음과 같다.
-
@app.get("/")
: 루트 URL로 접속 시 HTML 업로드 폼을 제공한다. -
@app.post("/upload/")
: 사용자가 이미지를 업로드하면, 해당 이미지를 처리하고 결과를 반환하며, 이미지는 uploads 폴더에 저장된다. 이후Inference3.process_image
함수를 호출하여 이미지를 처리한다. 결과는 CSV 파일에 저장되고 로그 파일에 기록된다. -
@app.get("/download_csv/")
: 결과 CSV 파일을 다운로드할 수 있는 라우트를 제공한다. -
@app.get("/results/")
: 처리된 결과 데이터를 JSON 형태로 반환한다. -
@app.get("/history/")
: 처리 기록을 보여주는 페이지를 제공한다. -
@app.get("/view-results/")
: 처리 결과를 보여주는 페이지를 제공한다. -
결과는
results
리스트에 딕셔너라 형태로 추가되며, 이를 csv 결과로 제공한다.
import ast
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
import os
from typing import List
import Inference3
import pandas as pd
app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount("/output", StaticFiles(directory="output"), name="output")
log_filename = "history.log" # 로그 파일 이름
results_data = []
def log_results(results):
with open(log_filename, "a") as log_file:
for result in results:
log_file.write(f"{result}\n")
@app.get("/", response_class=HTMLResponse)
async def upload_form(request: Request):
return templates.TemplateResponse("upload_form.html", {"request": request})
def save_results_to_csv(results, filename="results.csv"):
df = pd.DataFrame(results)
df.to_csv(filename, index=False)
@app.post("/upload/")
async def upload_images(files: List[UploadFile] = File(...)):
global results_data
results = []
for file in files:
file_location = f"uploads/{file.filename}"
with open(file_location, "wb+") as file_object:
file_object.write(await file.read()) # 비동기 파일 읽기
# Inference3 실행
args = Inference3.get_config()
args.img_path = file_location
leaf_area = Inference3.process_image(file_location, args)
output_image_path = os.path.join('output', os.path.splitext(file.filename)[0] + '.jpg')
output_image_url = f'/output/{os.path.basename(output_image_path)}'
results.append({"filename": file.filename, "leaf_area": leaf_area, "output_image_url": output_image_url})
save_results_to_csv(results)
results_data.append(results)
log_results(results) # 결과 로깅
return results
@app.get("/download_csv/")
async def download_csv():
return FileResponse("results.csv", media_type='application/octet-stream', filename="results.csv")
@app.get("/results/")
async def get_results():
global results_data
return results_data
@app.get("/history/", response_class=HTMLResponse)
async def view_history(request: Request):
with open("history.log", "r") as log_file:
history_data = [ast.literal_eval(line) for line in log_file.readlines()]
return templates.TemplateResponse("history_page.html", {"request": request, "history": history_data})
@app.get("/view-results/", response_class=HTMLResponse)
async def view_results(request: Request):
return templates.TemplateResponse("results_page.html", {"request": request})
# if __name__ == "__main__":
# import uvicorn
# uvicorn.run(app="server:app",
# host="0.0.0.0",
# port=8001,
# reload=True)
Python Arduino-Server 연결
-
http://113.198.63.26:13392 주소로 서버를 열고 포트포워딩하여 외부 IP에서도 접속이 가능하도록 세팅하였다.
-
PlantCV를 사용한 예시와 마찬가지로 SCI Camera Module을 사용하기 위해서는 파이프라인을 만들어주고, OpenCV를 사용하여 이미지를 캡쳐해야한다.
-
아두이노 버튼제어로 부터 들어온 신호에 따라 5번핀 버튼 신호가 들어오면 이미지를 캡쳐하고, 4번핀 버튼 신호가 들어오면, 이미지를 캡쳐하여 저장한 폴더 내에 있는 이미지를 모두 서버로 전송한다.
-
전송한 이후 캡쳐한 이미지를 저장하는 폴더는 리셋하여 이미지가 중복으로 전송되는 것을 방지한다.
import serial
import requests
import os
import cv2
from datetime import datetime
url = "http://113.198.63.26:13392/upload/"
ser = serial.Serial('/dev/ttyUSB0', 9600, timeout=1)
capture_counter = 0
def gstreamer_pipeline():
return (
"nvarguscamerasrc ! "
"video/x-raw(memory:NVMM), width=(int)1280, height=(int)1080, format=(string)NV12, framerate=(fraction)30/1 ! "
"nvvidconv ! video/x-raw, format=(string)BGRx ! "
"videoconvert ! video/x-raw, format=(string)BGR ! appsink"
)
def capture_image(filename):
if not os.path.exists('capture_images'):
os.makedirs('capture_images')
full_filename = os.path.join('capture_images', filename)
cap = cv2.VideoCapture(gstreamer_pipeline(), cv2.CAP_GSTREAMER)
if cap.isOpened():
ret, frame = cap.read()
if ret:
cv2.imwrite(full_filename, frame)
cap.release()
else:
print("Unable to open the camera")
if __name__=='__main__':
while True:
if ser.in_waiting > 0:
line = ser.readline().decode('utf-8').rstrip()
if line == "CAPTURE":
now = datetime.now()
filename = f"image_{now.strftime('%Y%m%d_%H%M')}_{capture_counter}.jpg" # 파일명에 카운터 추가
capture_image(filename)
capture_counter += 1
elif line == "SEND":
image_files = [os.path.join('capture_images', file) for file in os.listdir('capture_images') if file.endswith('.jpg') or file.endswith('.png')]
files = [('files', (open(file, 'rb'))) for file in image_files]
response = requests.post(url, files=files)
print(response.text)
for file in image_files:
os.remove(file)
break
Templates
FastAPI로 구축한 서버로 부터 라우팅 된 결과는 아래 템플릿을 통해 웹에 출력된다.
-
results_page.html
: Jetson으로 부터 전송 받은 결과를 출력하는 웹 페이지 -
hitory_page.html
: 지금까지 처리했던 결과를history.log
에 저장하고, 해당 이미지 이름, 결과를 웹 상에 테이블 형태로 출력
먼저 results_page.html
의 경우 코드의 핵심은 다음과 같다.
Bootstrap css
<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css">
결과 표시 영역
<div id="results" class="row justify-content-center">
</div>
서버로부터 받은 이미지 처리 결과를 동적으로 표시할 영역을 정의한다.
자바스크립트
fetch('/results/')
.then(response => response.json())
.then(data => {
// ... 결과를 처리하여 페이지에 표시
});
페이지가 로드될 때 서버로부터 이미지 처리 결과를 비동기적으로 요청하고, 받은 데이터를 웹 페이지에 동적으로 표시하는 자바스크립트 코드이다.
Flow chart
결과
results_page.html
의 결과
history_page.html
의 결과
성능
-
전반적으로 잎을 잘 분할하였다.
-
모델이 무거워 Jetson Nano에서는 실행이 불가능하였다. SAM은 무난하게 실행이 되나, CLIP이 실행이 불가능했다. 아직 Jetson Nano로 대형 딥러닝 모델을 실행하기에는 한계가 명확한 것 같다.
-
잎이 말려있는 경우는 당연하게도 결과에 차이가 나며, 카메라의 왜곡 현상으로 인하여 잎의 사이즈가 클 수록 외각 부분의 왜곡이 발생하여 오차가 발생한다.
-
여러 위치에서 촬영한 이미지를 바탕으로 사이즈를 산정하고, 위치에 따른 왜곡 정도를 머신러닝을 이용한 학습을 통해 극복하거나, CV의 홀로그램 기법과 같은 이미지 전처리를 통해 한계를 극복하는 방법이 필요할 것이다.
-
Li3100c와의 결과 비교