딥러닝 - SAM + CLIP을 접목한 엽면적 측정기 개발
SAM(Segmentation Anything Model)과 CLIP을 접목한 Leaf Segmentation 및 엽면적 계산



PlantCV를 이용한 엽면적 측정기 개발에 이어 SAM + CLIP 엽면적 측정기 개발에 대한 포스트를 이어서 진행하고자 한다. 이전 포스트인 PlantCV 엽면적 측정기를 먼저 읽고 오는 것을 추천한다.


SAM (Segmentation Anything Model)

SAM (Segmentation Anything Model)은 이전 딥러닝 포스트에서도 여러번 다루었던 모델이다. SAM 모델에 대하여 간략하게 설명하자면 다음과 같다.


SAM의 원리

SAM의 원리는 다음 참고 자료에 잘 설명 되어 있다.

Segment Anything 논문 정리
박진우 GitHub


SAM GitHub

GitHub: https://github.com/facebookresearch/segment-anything
Demo: https://segment-anything.com/demo


CLIP (Contrastive Language-Image Pre-training model)


CLIP GitHub

GitHub: https://github.com/openai/CLIP


SAM with CLIP

알고리즘의 순서를 간단하게 표현하면 다음과 같다.



FastSAM

SAM의 경우 Python 3.8이상 Pytorch 1.7이상 Torchvision 0.8 이상인 환경에서만 구동이 가능하다. 하지만 Jetson Nano는 기본 환경이 Python 3.6이며, Jetson nano Jetpack 환경은 파이썬 3.8이상의 라이브러리를 지원하지 않기 때문에 직접 Pytorch, OpenCV를 빌드해서 사용해야 했다.

기존의 SAM 모델의 경우 GPU가 아닌 CPU로 세팅하여 실행할 경우 계산 시간이 매우 오래걸리는 단점이 존재하였다. 원래 계획은 본 SAM + CLIP 모델을 Jetson Nano에 임베딩할 생각이였으나, Jetson Nano의 메모리가 CLIP을 실행하기에는 부족한 한계가 있었다.

이에 따라 SAM + CLIP 모델은 따로 데스크탑에서 구축한 서버에서 구동하고, Jetson은 아두이노와의 통신 및 카메라로 촬영한 이미지를 서버에 전송하는 중간 역할로 변경하게 되었다. 딥러닝 구동 서버는 FastAPI를 이용하여 구동하였다.


FastSAM은 기존의 SAM에서 사용하던 SA-1B 데이터 세트의 2%만을 사용하여 훈련된 CNN 모델이다. FastSAM은 50배 빠른 런타임 속도로 SAM 방식과 비슷한 성능을 보였다.



FastSAM GitHub

GitHub: https://github.com/CASIA-IVA-Lab/FastSAM


소스코드

위 깃허브 링크로 부터 FastSAM을 다운로드하고, 실행파일인 Inference.py를 이용하여 엽면적을 계산하도록 수정하였다.

import os
import argparse
from fastsam import FastSAM, FastSAMPrompt
import ast
import torch
import numpy as np
import pandas as pd
from PIL import Image
from utils.tools import convert_box_xywh_to_xyxy

def get_config():
    args = argparse.Namespace()
    args.model_path = "./weights/FastSAM-x.pt"
    args.img_path = "./uploads/"
    args.imgsz = 1024
    args.iou = 0.9
    args.text_prompt = "leaf"
    args.conf = 0.4
    args.output = "./output/"
    args.randomcolor = True
    args.point_prompt = "[[0,0]]"
    args.point_label = "[0]"
    args.box_prompt = "[[0,0,0,0]]"
    args.better_quality = False
    args.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    args.retina = True
    args.withContours = False
    return args

def process_image(img_path, args):
    model = FastSAM(args.model_path)
    args.point_prompt = ast.literal_eval(args.point_prompt)
    args.box_prompt = convert_box_xywh_to_xyxy(ast.literal_eval(args.box_prompt))
    args.point_label = ast.literal_eval(args.point_label)
    input = Image.open(img_path)
    input = input.convert("RGB")
    everything_results = model(
        input,
        device=args.device,
        retina_masks=args.retina,
        imgsz=args.imgsz,
        conf=args.conf,
        iou=args.iou
    )
    bboxes = None
    points = None
    point_label = None
    prompt_process = FastSAMPrompt(input, everything_results, device=args.device)
    if args.box_prompt[0][2] != 0 and args.box_prompt[0][3] != 0:
        ann = prompt_process.box_prompt(bboxes=args.box_prompt)
        bboxes = args.box_prompt

    elif args.text_prompt != None:
        ann = prompt_process.text_prompt(text=args.text_prompt)
        pixel_count = np.sum(ann)
        leaf_area = pixel_count * 0.000508
        leaf_area = round(leaf_area, 2)
        print("실제 잎의 면적: ", leaf_area, 'cm^2')

    elif args.point_prompt[0] != [0, 0]:
        ann = prompt_process.point_prompt(
            points=args.point_prompt, pointlabel=args.point_label
        )
        points = args.point_prompt
        point_label = args.point_label
    else:
        ann = prompt_process.everything_prompt()

    output_filename = os.path.splitext(os.path.basename(img_path))[0] + '.jpg'
    output_path = os.path.join(args.output, output_filename)

    prompt_process.plot(
        annotations=ann,
        output_path=output_path,
        bboxes=bboxes,
        points=points,
        point_label=point_label,
        withContours=args.withContours,
        better_quality=args.better_quality,
    )

    return leaf_area

def main(args):
    original_point_prompt = args.point_prompt
    original_box_prompt = args.box_prompt
    original_point_label = args.point_label

    results = []
    image_names = [] 
    for img_file in sorted(os.listdir(args.img_path)):
        if img_file.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_path = os.path.join(args.img_path, img_file)
            args.point_prompt = original_point_prompt
            args.box_prompt = original_box_prompt
            args.point_label = original_point_label
            results.append(process_image(img_path, args))
            image_names.append(img_file)

    df = pd.DataFrame({'Filename': image_names, 'Leaf_Area': results})
    df.to_csv('results.csv', index=False)

    print(results)

if __name__ == "__main__":
    args = get_config()
    main(args)


Arudino


#include <Wire.h>
#include <LiquidCrystal_I2C.h>

LiquidCrystal_I2C lcd(0x27, 16, 2);

const int debounceTime = 500;  // 딜레이 시간을 밀리초 단위로 설정

void setup() {
  lcd.init();
  lcd.backlight();
  Serial.begin(9600);
  pinMode(5, INPUT_PULLUP); // 5번 핀 ==> Capture 버튼
  pinMode(4, INPUT_PULLUP); // 4번 핀 ==> Send 버튼
}

void loop() {
  int buttonState = digitalRead(5);
  int exitButtonState = digitalRead(4);

  if (buttonState == LOW) {
      Serial.println("CAPTURE");
      lcd.clear();
      lcd.print("Capture complete");
      delay(debounceTime);  // 버튼 입력 후 일정 시간 동안 대기
  }

  if (exitButtonState == LOW) {
      Serial.println("SEND");
      lcd.clear();
      lcd.print("Send Complete");
      delay(debounceTime);  // 버튼 입력 후 일정 시간 동안 대기
  }
}


FastAPI

FastAPI를 사용하여 딥러닝 API 서버를 구축하였다. server.py의 라우팅은 다음과 같다.


import ast
from fastapi import FastAPI, File, UploadFile, Request
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
import os
from typing import List
import Inference3
import pandas as pd

app = FastAPI()
templates = Jinja2Templates(directory="templates")
app.mount("/output", StaticFiles(directory="output"), name="output")

log_filename = "history.log"  # 로그 파일 이름
results_data = []

def log_results(results):
    with open(log_filename, "a") as log_file:
        for result in results:
            log_file.write(f"{result}\n")

@app.get("/", response_class=HTMLResponse)
async def upload_form(request: Request):
    return templates.TemplateResponse("upload_form.html", {"request": request})

def save_results_to_csv(results, filename="results.csv"):
    df = pd.DataFrame(results)
    df.to_csv(filename, index=False)

@app.post("/upload/")
async def upload_images(files: List[UploadFile] = File(...)):
    global results_data
    results = []

    for file in files:
        file_location = f"uploads/{file.filename}"
        with open(file_location, "wb+") as file_object:
            file_object.write(await file.read())  # 비동기 파일 읽기

        # Inference3 실행
        args = Inference3.get_config()
        args.img_path = file_location
        leaf_area = Inference3.process_image(file_location, args)

        output_image_path = os.path.join('output', os.path.splitext(file.filename)[0] + '.jpg')
        output_image_url = f'/output/{os.path.basename(output_image_path)}'

        results.append({"filename": file.filename, "leaf_area": leaf_area, "output_image_url": output_image_url})

    save_results_to_csv(results)
    results_data.append(results)
    log_results(results)  # 결과 로깅

    return results

@app.get("/download_csv/")
async def download_csv():
    return FileResponse("results.csv", media_type='application/octet-stream', filename="results.csv")

@app.get("/results/")
async def get_results():
    global results_data
    return results_data

@app.get("/history/", response_class=HTMLResponse)
async def view_history(request: Request):
    with open("history.log", "r") as log_file:
        history_data = [ast.literal_eval(line) for line in log_file.readlines()]
    return templates.TemplateResponse("history_page.html", {"request": request, "history": history_data})

@app.get("/view-results/", response_class=HTMLResponse)
async def view_results(request: Request):
    return templates.TemplateResponse("results_page.html", {"request": request})

# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app="server:app",
#                     host="0.0.0.0",
#                     port=8001,
#                     reload=True)


Python Arduino-Server 연결


import serial
import requests
import os
import cv2
from datetime import datetime

url = "http://113.198.63.26:13392/upload/"
ser = serial.Serial('/dev/ttyUSB0', 9600, timeout=1)

capture_counter = 0
def gstreamer_pipeline():
    return (
        "nvarguscamerasrc ! "
        "video/x-raw(memory:NVMM), width=(int)1280, height=(int)1080, format=(string)NV12, framerate=(fraction)30/1 ! "
        "nvvidconv ! video/x-raw, format=(string)BGRx ! "
        "videoconvert ! video/x-raw, format=(string)BGR ! appsink"
    )

def capture_image(filename):
    if not os.path.exists('capture_images'):
        os.makedirs('capture_images')

    full_filename = os.path.join('capture_images', filename)
    cap = cv2.VideoCapture(gstreamer_pipeline(), cv2.CAP_GSTREAMER)
    if cap.isOpened():
        ret, frame = cap.read()
        if ret:
            cv2.imwrite(full_filename, frame)
        cap.release()
    else:
        print("Unable to open the camera")

if __name__=='__main__':
    while True:
        if ser.in_waiting > 0:
            line = ser.readline().decode('utf-8').rstrip()

            if line == "CAPTURE":
                now = datetime.now()
                filename = f"image_{now.strftime('%Y%m%d_%H%M')}_{capture_counter}.jpg"  # 파일명에 카운터 추가
                capture_image(filename)
                capture_counter += 1

            elif line == "SEND":
                image_files = [os.path.join('capture_images', file) for file in os.listdir('capture_images') if file.endswith('.jpg') or file.endswith('.png')]
                files = [('files', (open(file, 'rb'))) for file in image_files]

                response = requests.post(url, files=files)
                print(response.text)

                for file in image_files:
                    os.remove(file)
                break


Templates

FastAPI로 구축한 서버로 부터 라우팅 된 결과는 아래 템플릿을 통해 웹에 출력된다.

먼저 results_page.html의 경우 코드의 핵심은 다음과 같다.

Bootstrap css

<link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css">


결과 표시 영역

<div id="results" class="row justify-content-center">
</div>

서버로부터 받은 이미지 처리 결과를 동적으로 표시할 영역을 정의한다.


자바스크립트

fetch('/results/')
    .then(response => response.json())
    .then(data => {
        // ... 결과를 처리하여 페이지에 표시 
    });

페이지가 로드될 때 서버로부터 이미지 처리 결과를 비동기적으로 요청하고, 받은 데이터를 웹 페이지에 동적으로 표시하는 자바스크립트 코드이다.


Flow chart


결과



성능