Controllable Diffusion Model

Knowledge Distillation(KD)대형 모델(Teacher)소형 모델(Student) 에 지식을 전달하여, 소형 모델의 복잡도를 증가시키지 않고 성능을 향상시키는 기법이다.
MMDetection(인기 있는 객체 탐지 프레임워크)에서 Knowledge Distillation(KD) 를 구현하기 위해서는, Teacher-Student 학습 패러다임을 사용하여 학습을 진행한다.

"Controllable Diffusion Model" 논문을 기반으로 한 객체 탐지 모델에 Knowledge Distillation을 통합하려면, trainer.py 파일과 config 파일을 수정하여 KD 프레임워크를 지원해야한다.

import os
import datetime
import uuid
import wandb
from mmcv import Config
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector, set_random_seed

# 추가: KD를 위한 Helper 함수들
from kd_utils import knowledge_distillation_loss

wandb.login()

def main():
    """메인 실행 함수"""
    
    # 실험 이름 생성
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    random_code = str(uuid.uuid4())[:5]
    
    # 설정 생성
    cfg, model_name, output_dir = create_config('atss')
    experiment_dir = os.path.join(output_dir, f"{timestamp}_{random_code}")
    os.makedirs(experiment_dir, exist_ok=True)
    cfg.work_dir = experiment_dir

    wandb.init(
        project="Object Detection", 
        dir=experiment_dir,
        name=f'{model_name}_{random_code}',
        config=cfg._cfg_dict.to_dict()
    )
    
    # Optimizer 설정
    cfg.optimizer = dict(
        type='SGD', 
        lr=wandb.config.lr, 
        momentum=0.9,
        weight_decay=wandb.config.weight_decay
    )
    
    # 데이터셋 및 모델 빌드
    datasets = [build_dataset(cfg.data.train)]
    student_model = build_detector(cfg.model)
    teacher_model = build_detector(cfg.model)  # 같은 구조로 teacher 모델 생성
    teacher_model.init_weights()
    teacher_model.eval()  # Teacher 모델 고정

    student_model.init_weights()

    # Knowledge Distillation Loss 적용
    for epoch in range(cfg.runner.max_epochs):
        for batch in datasets:
            # Forward pass: 학생 모델 예측
            student_pred = student_model(batch['img'])
            
            # Forward pass: 교사 모델 예측 (고정된 모델)
            with torch.no_grad():
                teacher_pred = teacher_model(batch['img'])
            
            # KD 손실 계산
            kd_loss = knowledge_distillation_loss(student_pred, teacher_pred, cfg.kd_loss_weight)

            # 학습 및 손실 업데이트
            loss = kd_loss + student_model.compute_loss(batch, student_pred)
            loss.backward()
            optimizer.step()

if __name__ == "__main__":
    sweep_configuration = {
        "method": "bayes",
        "metric": {"goal": "maximize", "name": "val/bbox_mAP_50"},
        "parameters": {
            "lr": {"max": 0.003, "min": 0.0001},
            "weight_decay": {"max": 0.01, "min": 0.0001}
        },
        "early_terminate":{
            "type": "hyperband",
            "s": 3,
            "eta": 2,
            "min_iter": 8,
        }
    }

    sweep_id = wandb.sweep(sweep=sweep_configuration, project='ATSS')
    wandb.agent(sweep_id, function=main, count=2)

2. Config File Adjustments

# Config 파일의 추가 내용
kd_loss_weight = 0.5  # 교사-학생 학습에서의 손실 가중치

model = dict(
    type='ATSS',
    backbone=dict(
        type='ResNet',
        depth=50,
        init_cfg=dict(type='Pretrained', checkpoint='open-mmlab://detectron2/resnet50_caffe'),
    ),
    neck=dict(type='FPN', in_channels=[256, 512, 1024, 2048], out_channels=256, num_outs=5),
    bbox_head=dict(
        type='ATSSHead',
        num_classes=80,  # 클래스 수 맞추기
    ),
    teacher=dict(  # Teacher 모델 설정 추가
        type='ATSS',
        pretrained='path/to/teacher_model.pth',  # Teacher 모델 경로 설정
        backbone=dict(type='ResNet', depth=101),  # 교사는 더 큰 모델일 가능성이 큼
    )
)

# 손실 함수 설정
kd_loss = dict(
    type='DistillationLoss',
    temperature=4,  # Knowledge Distillation 온도 하이퍼파라미터
)

3. Knowledge Distillation Loss Function (kd_utils.py)

import torch.nn.functional as F

def knowledge_distillation_loss(student_pred, teacher_pred, kd_weight=0.5, temperature=4):
    """
    Knowledge Distillation 손실 함수
    student_pred: 학생 모델의 예측 결과
    teacher_pred: 교사 모델의 예측 결과
    kd_weight: Knowledge Distillation 손실 가중치
    temperature: Knowledge Distillation의 온도
    """
    # softmax로 예측 결과를 변환
    teacher_probs = F.softmax(teacher_pred / temperature, dim=-1)
    student_probs = F.log_softmax(student_pred / temperature, dim=-1)
    
    # KL Divergence Loss 계산 (soft target 기반)
    kd_loss = F.kl_div(student_probs, teacher_probs, reduction='batchmean') * (temperature ** 2)
    
    return kd_weight * kd_loss