Giới thiệu

Table of Content


Import & configure

!pip install kaggle
from google.colab import files
uploaded = files.upload()
for fn in uploaded.keys():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=fn, length=len(uploaded[fn])))
# Then move kaggle.json into the folder where the API expects to find it.
!mkdir -p ~/.kaggle/ && mv kaggle.json ~/.kaggle/ && chmod 600 ~/.kaggle/kaggle.json
from google.colab import drive
drive.mount('./gdrive')
!kaggle datasets download -d brendanartley/lumbar-coordinate-pretraining-dataset
!unzip -qq "/content/lumbar-coordinate-pretraining-dataset.zip"
!pip install -q pytorch-lightning & pip install -q -U albumentations & pip install -q iterative-stratification
!pip install -q timm & pip install -q einops & pip install -q pytorch-lightning wandb & pip install torch-ema
!git clone https://github.com/mlpc-ucsd/CoaT
!pip install -q pydicom
import gc
import wandb
from pytorch_lightning.loggers import WandbLogger
import os
import yaml
import sys
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torch
from glob import glob
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW, Adam
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, TQDMProgressBar
import torchvision.transforms as T
import albumentations as A
import pandas.api.types
import sklearn.metrics
import timm
import scipy
import albumentations as A
from torchvision.transforms import v2
from torchvision import models
from tqdm.auto import tqdm
from joblib import Parallel, delayed
from torch.utils.data import default_collate
import pydicom as dcm
import transformers
SEED = 126 # friend's birthday
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True # Fix the network according to random seed
    print('Finish seeding with seed {}'.format(seed))
seed_everything(SEED)
print('Training on device {}'.format(device))
time
if config['debug']: 
    \#meta_df = pd.read_parquet('/kaggle/input/rsna-newmeta/meta.parquet')
    meta_df_list = []
    meta_df_list = Parallel(n_jobs=-1)([delayed(create_dcm_df)(row.study_id, row.series_id, row.series_description) for _, row in series.iterrows()])
    \#for _, row in tqdm(test_series.iterrows(), total=len(test_series)): 
    #    meta_df_list.append(create_dcm_df(row.study_id, row.series_id, row.series_description))
    meta_df = pd.concat(meta_df_list)
    del meta_df_list
    gc.collect()
    meta_df.to_parquet('meta.parquet')
else: 
    # spend about 20 min for train data
    meta_df_list = []
    meta_df_list = Parallel(n_jobs=-1)([delayed(create_dcm_df)(row.study_id, row.series_id, row.series_description) for _, row in series.iterrows()])
    \#for _, row in tqdm(test_series.iterrows(), total=len(test_series)): 
    #    meta_df_list.append(create_dcm_df(row.study_id, row.series_id, row.series_description))
    meta_df = pd.concat(meta_df_list)
    del meta_df_list
    gc.collect()
    meta_df.to_parquet('meta.parquet')

CPU times: user 290 ms, sys: 99.4 ms, total: 390 ms
Wall time: 3.09 s

`time prefix = ” import warnings warnings.filterwarnings(“ignore”) depth_predict = {‘scs’: { ‘L1/L2’:[], ‘L2/L3’: [], ‘L3/L4’: [], ‘L4/L5’: [], ‘L5/S1’: [] }, ‘nfn’: { ‘left_L1/L2’: [], ‘left_L2/L3’: [], ‘left_L3/L4’: [], ‘left_L4/L5’: [], ‘left_L5/S1’: [], ‘right_L1/L2’: [], ‘right_L2/L3’: [], ‘right_L3/L4’: [], ‘right_L4/L5’: [], ‘right_L5/S1’: [], } } model_path_dict = { ‘scs’: [ ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_0.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_1.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_2.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_3.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_4.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_0.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_2.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_3.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_4.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_l1_0.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_l1_1.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_l1_2.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_l1_3.ckpt’, ‘/kaggle/input/rsna-spine-final-models/scs_depth_1024_ssr_l1_4.ckpt’, ], ‘nfn’: [ ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_0.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_1.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_2.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_3.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_4.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_0.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_2.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_3.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_4.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_l1_0.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_l1_1.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_l1_2.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_l1_3.ckpt’, ‘/kaggle/input/rsna-spine-final-models/nfn_depth_1024_ssr_l1_4.ckpt’, ] } ##############DEPTH DETECT######################### for condition in [‘nfn’, ‘scs’]: print(condition) model_path_list = model_path_dict[condition] for model_path in model_path_list: _meta_df = meta_df.copy() #_series = series.copy() dataset_test = DepthDetectDataset(_meta_df, condition, ‘sub’) data_loader_test = DataLoader( dataset_test, batch_size=config[“test_bs”], shuffle=False, num_workers=4, pin_memory=False ) model_name = model_path.split(’/’)[-1] if ‘1024’ in model_name: widths = [128, 256, 512, 1024] else: widths = [64, 128, 256, 512] if ‘l1’ in model_name: model_type = ‘regression’ else: model_type = ‘classification’ model = DepthDetectModule.load_from_checkpoint(model_path, condition=condition, widths=widths, model_type=model_type) model.eval() model.zero_grad() model.to(device) pred_temp = {} for k in depth_predict[condition].keys(): pred_temp[k] = [] study_id_list = [] with torch.no_grad(): for data in tqdm(data_loader_test, total=len(data_loader_test)): images, study_id = data images = images.to(device) preds = model.forward(images) #print(preds) if model_type == ‘regression’: for k, v in preds.items(): pred_temp[k].append((v[:, -1]*32).to(‘cpu’).detach().numpy()) else: for k, v in preds.items(): pred_temp[k].append(torch.argmax(v, dim=1).to(‘cpu’).detach().numpy()) study_id_list.append(study_id.to(‘cpu’).reshape(-1).detach().numpy()) del images, study_id, preds gc.collect() for k, v in pred_temp.items(): depth_predict[condition][k].append(np.concatenate(v)) study_id = np.concatenate(study_id_list) del pred_temp, study_id_list gc.collect()

for k, v in depth_predict[condition].items(): 
    depth_predict[condition][k] = np.median(np.array(depth_predict[condition][k]), axis=0)
depth_predict[condition]['study_id'] = study_id
del study_id
gc.collect()
**Pipeline inference và ensemble** cho mô hình **Depth Detection**, với các mô hình và checkpiont khác nhau trên cả hai condition: `'scs'` và `'nfn'`.
Ensemble là **kỹ thuật kết hợp nhiều mô hình** lại với nhau để tạo ra **kết quả cuối cùng tốt hơn, ổn định hơn** so với chỉ dùng một mô hình duy nhất.
Tổng thể:
- **Duyệt qua từng loại mô hình (**`**scs**`**,** `**nfn**`**),** load từng checkpoint mô hình, chạy dự đoán trên toàn bộ test set.
- **Gộp (ensemble)** kết quả dự đoán từ nhiều mô hình khác nhau (nhiều checkpoint, nhiều fold) → lấy giá trị trung vị (median) của từng đốt sống cho mỗi study (patient).
- **Tối ưu cho đánh giá chính xác nhất trên leaderboard hoặc trong thực tế.**
  
**1. Khởi tạo dictionary lưu kết quả dự đoán**
```Python
depth_predict = {
		'scs': {'L1/L2':[], ...},
		'nfn': {'left_L1/L2':[], ..., 'right_L5/S1':[]}
}
  • Mỗi key sẽ lưu lại tất cả dự đoán từng mô hình, từng batch, từng study.
  • Giúp gộp lại cho bước ensemble. 2. **model_path_dict**
  • Dict chứa danh sách các checkpoint mô hình đã train cho từng condition (scs, nfn).
  • Các checkpoint này có thể là nhiều fold, nhiều config, nhiều loại (classification/regression). 3. Vòng lặp chính qua 2 condition: for condition in ['nfn', 'scs']: …. → Chạy lần lượt cho từng loại bệnh hoặc label khác nhau. 4. Lặp qua từng checkpoint mô hình Với mỗi mô hình đã train (checkpoint khác nhau), sẽ chạy dự đoán full test set. 5. Chuẩn bị dữ liệu test, xác định backbone và loại task:
dataset_test = DepthDetectDataset(_meta_df, condition, 'sub')
data_loader_test = DataLoader(dataset_test, batch_size=config["test_bs"], ...)
if '1024' in model_name: 
    widths = [128, 256, 512, 1024]
else: 
    widths = [64, 128, 256, 512]
if 'l1' in model_name: 
    model_type = 'regression'
else: 
    model_type = 'classification'
  • Chọn cấu trúc backbone phù hợp với mô hình (dựa vào tên file checkpoint).
  • Dùng đúng loại regression/classification tương ứng. 6. Load model từ checkpoint
model = DepthDetectModule.load_from_checkpoint(
    model_path, condition=condition, widths=widths, model_type=model_type
)
model.eval()
model.zero_grad()
model.to(device)
  • Load mô hình đã huấn luyện.
  • Đưa sang chế độ eval, chuyển lên GPU/CPU tùy cấu hình. 7. Vòng lặp inference từng batch
with torch.no_grad():
    for data in tqdm(data_loader_test, total=len(data_loader_test)):
        images, study_id = data
        images = images.to(device)
        preds = model.forward(images)
        if model_type == 'regression': 
            for k, v in preds.items(): 
                pred_temp[k].append((v[:, -1]*32).to('cpu').detach().numpy())
        else: 
            for k, v in preds.items(): 
                pred_temp[k].append(torch.argmax(v, dim=1).to('cpu').detach().numpy())
        study_id_list.append(study_id.to('cpu').reshape(-1).detach().numpy())
  • Nếu là regression, lấy giá trị cuối cùng trên trục class, nhân lại với depth chuẩn (32), convert về numpy.
  • Nếu classification, lấy index lớn nhất (class dự đoán).
  • Lưu lại dự đoán và study_id từng batch. 8. Gộp kết quả từng model
for k, v in pred_temp.items(): 
    depth_predict[condition][k].append(np.concatenate(v))
study_id = np.concatenate(study_id_list)
  • Sau mỗi mô hình, gộp kết quả các batch thành một mảng lớn cho mỗi đốt sống. 9. Ensemble kết quả nhiều model/checkpoint
for k, v in depth_predict[condition].items(): 
    depth_predict[condition][k] = np.median(np.array(depth_predict[condition][k]), axis=0)
depth_predict[condition]['study_id'] = study_id
  • Sau khi chạy hết tất cả checkpoint, tính giá trị trung vị (median) trên trục model cho mỗi đốt sống, mỗi study.
  • Kết quả là dự đoán đã được ensemble tối ưu. 10. Dọn bộ nhớ liên tục với **gc.collect()**

6. Create label coordinate & align

def create_label_ins(study_id, depth, level, condition, desc): 
    coor_dict = {'study_id': [], 'series_id': [], 'instance_number': []}
    _meta = meta_df.loc[meta_df.series_description==desc]
    for s, d in zip(study_id, depth): 
        sub_meta = _meta.loc[_meta.study_id==s]
        sub_meta = sub_meta.sort_values('ipp_x', ascending=True).reset_index(drop=True)
        if len(sub_meta) > 32: 
            d = (d/32)*len(sub_meta)
        try: 
            row = sub_meta.iloc[round(d)]
        except: 
            if condition == 'Spinal Canal Stenosis': 
                row = sub_meta.iloc[int(len(sub_meta)//2)]
            elif condition == 'Left Neural Foraminal Narrowing': 
                row = sub_meta.iloc[int(2*(len(sub_meta)//3))]
            elif condition == 'Right Neural Foraminal Narrowing': 
                row = sub_meta.iloc[int(len(sub_meta)//3)]
            print(s)
        coor_dict['study_id'].append(s)
        coor_dict['series_id'].append(row.series_id)
        coor_dict['instance_number'].append(row.instance_number)
    coor_dict['condition'] = condition
    coor_dict['level'] = level.split('_')[-1]
    return pd.DataFrame(coor_dict)

Từ một list study_id và dự đoán depth (tọa độ slice/lát), tìm lại thông tin DICOM slice tương ứng (series_id, instance_number, …) cho từng case. Dùng để:

  • join với label thật trong evaluate

  • submit lên hệ thống chấm điểm tự động

  • hiển thị hình ảnh lát cắt đúng trên giao diện

scs_study_id = depth_predict['scs']['study_id']
scs_coor_list = []
for k, v in depth_predict['scs'].items(): 
    if k != 'study_id': 
        scs_coor_list.append(create_label_ins(scs_study_id, v, k, 'Spinal Canal Stenosis', 'Sagittal T2/STIR'))
nfn_study_id = depth_predict['nfn']['study_id']
nfn_coor_list = []
for k, v in depth_predict['nfn'].items(): 
    if k != 'study_id': 
        if k.split('_')[0] == 'left': 
            condition = 'Left Neural Foraminal Narrowing'
        else: 
            condition = 'Right Neural Foraminal Narrowing'
        nfn_coor_list.append(create_label_ins(nfn_study_id, v, k, condition, 'Sagittal T1'))
scs_coor = pd.concat(scs_coor_list)
nfn_coor = pd.concat(nfn_coor_list)
pred_coor = pd.concat([scs_coor, nfn_coor]).sort_values(['study_id', 'series_id', 'level'])
  • Lấy kết quả dự đoán từ hai nhóm bệnh: SCS (Spinal Canal Stenosis)NFN (Neural Foraminal Narrowing).
  • Duyệt qua từng vị trí (level) của từng nhóm, sử dụng hàm create_label_ins để tìm lại thông tin lát ảnh thật trong DICOM ứng với dự đoán.
  • Gộp toàn bộ kết quả lại thành một bảng duy nhất, sắp xếp chuẩn bị cho các bước xử lý tiếp theo.
del scs_coor, nfn_coor, scs_coor_list, nfn_coor_list, depth_predict
gc.collect()
pred_coor.head()
pred_coor.to_csv('stage1_coor.csv', index=False)
  • Giải phóng RAM, tránh chiếm dụng không cần thiết (rất quan trọng khi làm việc với dữ liệu y tế lớn).
  • gc.collect(): ép Python thực hiện “garbage collection” ngay.
  • Hiển thị 5 dòng đầu của DataFrame pred_coor để kiểm tra lại dữ liệu trước khi lưu (optional, hay dùng trong notebook).
  • Ghi toàn bộ DataFrame pred_coor ra file CSV (không kèm chỉ số dòng)

II. Second Stage (xy inference)

Infer xy-coordinate of locations of sagittal t1 & t2

  • Mục tiêu: Với mỗi lát ảnh sagittal (T1 hoặc T2), mô hình dự đoán tọa độ (x, y) của vị trí cần xác định (ví dụ: vị trí tổn thương, vị trí giữa hai đốt sống,…). Ensemble or align (rule base)
  • Ensemble: Kết hợp nhiều mô hình hoặc nhiều phép thử lại để lấy kết quả ổn định hơn (trung bình, trung vị, voting…).
  • Align (rule base): Có thể dùng các quy tắc (rule-based) để điều chỉnh lại vị trí, ví dụ:
    • Nếu T1 và T2 khác biệt nhiều về tọa độ, chọn điểm gần trung tâm nhất.
    • Nếu kết quả nằm ngoài vùng hợp lệ của ảnh, ép vào vùng hợp lệ.
    • Nếu giữa hai lát (slice) dự đoán gần nhau, có thể lấy trung bình, hoặc chọn điểm đã có trong ground-truth.

1. Coordinate prediction dataset

Khởi tạo

class CoorDetectDataset(Dataset):
    def __init__(self, coor, meta, condition, usage='train'):
        if condition == 'scs':
            coor = coor.loc[coor.condition=='Spinal Canal Stenosis']
        elif condition == 'ss':
            coor = coor.loc[(coor.condition=='Left Subarticular Stenosis') | (coor.condition=='Right Subarticular Stenosis')]
        elif condition == 'nfn':
            coor = coor.loc[(coor.condition=='Right Neural Foraminal Narrowing') | (coor.condition=='Left Neural Foraminal Narrowing')]
        \#g_coor = coor.groupby('study_id').count()
        \#if condition == 'scs':
        #    self.id = g_coor.loc[g_coor.series_id==5].reset_index().study_id.unique()
        \#else:
        #    self.id = g_coor.loc[g_coor.series_id==10].reset_index().study_id.unique()
        self.id = coor.study_id.unique()
        self.coor = coor
        self.meta = meta
        self.condition = condition
        self.usage = usage
        if 3637444890 in self.id: 
            self.id.remove(3637444890)
        \#self.id = [2773343225]
        \#self.id = [1782095928]
        self.resize = v2.Resize((384, 384))
        
		def __getitem__(self, index):
        study_id = self.id[index]
        \#print(study_id)
        \#try:
        if self.condition == 'scs':
            volume = self.for_scs(study_id)
        elif self.condition == 'nfn':
            volume = self.for_nfn(study_id)
        if self.condition == 'ss':
            volume = self.for_ss(study_id)
        return volume, torch.tensor(study_id)

Mục tiêu chính:

  • Chọn ra các study phù hợp theo từng loại bệnh học (**condition**).
  • Chuẩn bị sample để huấn luyện hoặc inference: với mỗi study_id, trả về một volume ảnh đã tiền xử lý sẵn, cùng label hoặc thông tin kèm theo.
  • Dễ dàng mở rộng/biến đổi theo từng loại task (scs, ss, nfn).

1. Lọc dữ liệu theo condition

  • Lấy đúng các dòng label trong DataFrame coor thuộc nhóm bệnh cần train/infer. 2. Lấy danh sách study hợp lệ self.id = coor.study_id.unique()
  • Tạo mảng gồm tất cả study_id không trùng lặp cho bài toán. 3. Loại bỏ study lỗi hoặc không hợp lệ
  • Tránh lỗi khi train/infer (study này thiếu hoặc hỏng file). 4. Lưu các thuộc tính cần thiết
  • self.coor: label filtered
  • self.meta: metadata của DICOM (để truy xuất file, series, vị trí slice)
  • self.condition : condition
  • self.usage : usage
  • self.resize: transform resize ảnh về kích thước chuẩn

getitem

  • Dựa vào condition, gọi đúng hàm load volume (ảnh 3D) đã tiền xử lý:
    • for_scs(study_id) cho Spinal Canal Stenosis
    • for_nfn(study_id) cho Neural Foraminal Narrowing
    • for_ss(study_id) cho Subarticular Stenosis
  • Trả về:
    • volume: tensor ảnh 3D, đã chuẩn hóa kích thước (thường là [D, H, W])
    • study_id: tensor chứa id bệnh nhân, dùng để join kết quả hoặc debug.

Func for_scs

    def for_scs(self, study_id):
        meta = self.meta.loc[(self.meta.study_id==study_id) & (self.meta.series_description=='Sagittal T2/STIR')]
        meta = meta.sort_values('ipp_x', ascending=True).reset_index(drop=True)
        \#img = [self.normalize(self.load_dicom(f'/content/train_images/{row.study_id}/{row.series_id}/{row.instance_number}.dcm')) for _, row in meta.iterrows()]
        coor = self.coor.loc[(self.coor.study_id==study_id) & (self.coor.condition=='Spinal Canal Stenosis')]
        meta_list = []
        for _, row in coor.iterrows():
            series_id, instance_number = row.series_id, row.instance_number
            meta_list.append(meta.loc[(meta.series_id==series_id) & (meta.instance_number==instance_number)])
        sub_meta = pd.concat(meta_list)
        idx = meta.loc[meta.ipp_x == sub_meta.ipp_x.median()].index[0]
        \#print(old_idx)
        img_row = meta.iloc[idx]
        before_img_row = meta.iloc[idx-1]
        after_img_row = meta.iloc[idx+1]
        img = self.normalize(self.load_dicom(IMAGE_PATH + f'{img_row.study_id}/{img_row.series_id}/{img_row.instance_number}.dcm'))
        bimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{before_img_row.study_id}/{before_img_row.series_id}/{before_img_row.instance_number}.dcm'))
        aimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{after_img_row.study_id}/{after_img_row.series_id}/{after_img_row.instance_number}.dcm'))
        img = self.resize(torch.tensor(img[None, ...]))
        bimg = self.resize(torch.tensor(bimg[None, ...]))
        aimg = self.resize(torch.tensor(aimg[None, ...]))
        img = torch.cat([bimg, img, aimg]).to(torch.float32)
        return img

1. Lọc metadata của study SCS với đúng series

meta = self.meta.loc[
    (self.meta.study_id==study_id) &
    (self.meta.series_description=='Sagittal T2/STIR')
]
meta = meta.sort_values('ipp_x', ascending=True).reset_index(drop=True)
  • Lấy tất cả lát ảnh thuộc đúng study và đúng loại MRI (Sagittal T2/STIR).
  • Sắp xếp theo trục không gian (ipp_x) để đảm bảo đúng thứ tự lát.
  1. Tìm vị trí các lát liên quan đến tổn thương
coor = self.coor.loc[
    (self.coor.study_id==study_id) &
    (self.coor.condition=='Spinal Canal Stenosis')
]
meta_list = []
for _, row in coor.iterrows():
    series_id, instance_number = row.series_id, row.instance_number
    meta_list.append(meta.loc[
        (meta.series_id==series_id) & (meta.instance_number==instance_number)
    ])
sub_meta = pd.concat(meta_list)
  • Tìm trong bảng label (self.coor) tất cả các instance của study này với condition SCS.
  • Với từng label, tìm lại metadata của lát ảnh đó.
  • Ghép lại thành bảng con sub_meta – tập hợp các lát ảnh liên quan.
  1. Tìm slice trung tâm nhất trong vùng tổn thương
idx = meta.loc[meta.ipp_x == sub_meta.ipp_x.median()].index[0]
  • Chọn tọa độ ipp_x trung vị trong tập các lát liên quan (giữa vùng tổn thương) → idx là chỉ số lát trung tâm.
  1. Lấy 3 lát ảnh: trung tâm, trước và sau
img_row = meta.iloc[idx]
before_img_row = meta.iloc[idx-1]
after_img_row = meta.iloc[idx+1]
  • img_row: lát trung tâm
  • before_img_row: lát trước (idx-1)
  • after_img_row: lát sau (idx+1)
  1. Đọc và chuẩn hóa 3 lát ảnh
img = self.normalize(self.load_dicom(IMAGE_PATH + f'{img_row.study_id}/{img_row.series_id}/{img_row.instance_number}.dcm'))
bimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{before_img_row.study_id}/{before_img_row.series_id}/{before_img_row.instance_number}.dcm'))
aimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{after_img_row.study_id}/{after_img_row.series_id}/{after_img_row.instance_number}.dcm'))
  • Đọc từng file DICOM, chuẩn hóa giá trị pixel (theo hàm bạn định nghĩa).
  1. Resize các lát về kích thước chuẩn
img = self.resize(torch.tensor(img[None, ...]))
bimg = self.resize(torch.tensor(bimg[None, ...]))
aimg = self.resize(torch.tensor(aimg[None, ...]))
  • Đưa ảnh từ numpy → tensor, thêm chiều channel [1, H, W], resize về [1, 384, 384]
  1. Ghép 3 lát thành một tensor đầu vào
img = torch.cat([bimg, img, aimg]).to(torch.float32)
  • Output cuối: tensor shape [3, 384, 384] (3 channel: lát trước, trung tâm, sau).
  • Dùng làm input cho mô hình detect xy (mô hình sẽ nhìn cận cảnh hơn, tăng độ chính xác).
  1. Trả về ảnh return img
		def for_ss(self, study_id):
        meta = self.meta.loc[(self.meta.study_id==study_id) & (self.meta.series_description=='Axial T2')]
        meta = meta.sort_values('ipp_z', ascending=False).reset_index(drop=True)
        img = [self.normalize(self.load_dicom(f'/content/train_images/{row.study_id}/{row.series_id}/{row.instance_number}.dcm')) for _, row in meta.iterrows()]
        coor = self.coor.loc[(self.coor.study_id==study_id)]
        coor_dict = {}
        for _, row in coor.iterrows():
            series_id, instance_number = row.series_id, row.instance_number
            target_row = meta.loc[(meta.series_id==series_id) & (meta.instance_number==instance_number)]
            idx = target_row.index[0]
            \#print(row.level, idx, idx/len(img))
            \#plt.title(row.level)
            \#plt.imshow(img[idx])
            \#mask = torch.zeros(img[idx].shape)
            \#mask[int(row.y)-10:int((row.y))+10, int(row.x)-10:int((row.x))+10] = 1
            \#plt.imshow(mask, alpha=0.5)
            \#plt.show()
            height, width = img[idx].shape
            z = idx/depth if len(img) < depth else idx/len(img)
            x = row.x/width
            y = row.y/height
            if row.condition == 'Right Subarticular Stenosis':
                coor_dict['right_' + row.level] = torch.tensor([x, y, z]).to(torch.float32)
						else:
                coor_dict['left_' + row.level] = torch.tensor([x, y, z]).to(torch.float32)
        volume = torch.cat([self.resize(torch.tensor(i)[None, ...]).to(torch.float32) for i in img]).contiguous()
        if volume.shape[0] < depth:
            volume = torch.cat([volume, torch.zeros(depth-volume.shape[0], volume.shape[1], volume.shape[2])])
        elif volume.shape[0] > depth:
            volume = torch.nn.functional.interpolate(volume[None, None, ...], (depth, volume.shape[1], volume.shape[2])).squeeze()
        return volume, coor_dict
def for_nfn(self, study_id):
        meta = self.meta.loc[(self.meta.study_id==study_id) & (self.meta.series_description=='Sagittal T1')]
        meta = meta.sort_values('ipp_x', ascending=True).reset_index(drop=True)
        \#img = [self.normalize(self.load_dicom(f'/content/train_images/{row.study_id}/{row.series_id}/{row.instance_number}.dcm')) for _, row in meta.iterrows()]
        coor = self.coor.loc[(self.coor.study_id==study_id)]
        right_meta_list = []
        left_meta_list = []
        for _, row in coor.iterrows():
            series_id, instance_number = row.series_id, row.instance_number
            if row.condition == 'Right Neural Foraminal Narrowing':
                right_meta_list.append(meta.loc[(meta.series_id==series_id) & (meta.instance_number==instance_number)])
            else: 
                left_meta_list.append(meta.loc[(meta.series_id==series_id) & (meta.instance_number==instance_number)])
        right_sub_meta = pd.concat(right_meta_list)
        left_sub_meta = pd.concat(left_meta_list)
        ridx = meta.loc[meta.ipp_x == right_sub_meta.ipp_x.median()].index[0]
        lidx = meta.loc[meta.ipp_x == left_sub_meta.ipp_x.median()].index[0]
        right_img_row = meta.iloc[min(max(ridx, 0), len(meta)-1)]
        \#display(right_img_row)
        right_before_img_row = meta.iloc[min(max(ridx-1, 0), len(meta)-1)]
        rightafter_img_row = meta.iloc[min(max(ridx+1, 0), len(meta)-1)]
        left_img_row = meta.iloc[min(max(lidx, 0), len(meta)-1)]
        left_before_img_row = meta.iloc[min(max(lidx-1, 0), len(meta)-1)]
        leftafter_img_row = meta.iloc[min(max(lidx+1, 0), len(meta)-1)]
        rimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{right_img_row.study_id}/{right_img_row.series_id}/{right_img_row.instance_number}.dcm'))
        rbimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{right_before_img_row.study_id}/{right_before_img_row.series_id}/{right_before_img_row.instance_number}.dcm'))
        raimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{rightafter_img_row.study_id}/{rightafter_img_row.series_id}/{rightafter_img_row.instance_number}.dcm'))
        limg = self.normalize(self.load_dicom(IMAGE_PATH + f'{left_img_row.study_id}/{left_img_row.series_id}/{left_img_row.instance_number}.dcm'))
        lbimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{left_before_img_row.study_id}/{left_before_img_row.series_id}/{left_before_img_row.instance_number}.dcm'))
        laimg = self.normalize(self.load_dicom(IMAGE_PATH + f'{leftafter_img_row.study_id}/{leftafter_img_row.series_id}/{leftafter_img_row.instance_number}.dcm'))
        rimg = torch.cat([self.resize(torch.tensor(i)[None, ...]).to(torch.float32) for i in [rbimg, rimg, raimg]])
        limg = torch.cat([self.resize(torch.tensor(i)[None, ...]).to(torch.float32) for i in [lbimg, limg, laimg]])
        img = torch.stack([limg, rimg]).to(torch.float32).contiguous()
        return img
    def normalize(self, x):
        lower, upper = np.percentile(x, (1, 99))
        x = np.clip(x, lower, upper)
        x = x - np.min(x)
        x = x / np.max(x)
        return x
    def __len__(self):
        return len(self.id)
    def load_dicom(self, path):
        dicom = dcm.read_file(path)
        data = dicom.pixel_array
        return data

2. Coordinate prediction models

class ConvNextSCSDetect(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        \#self.size = 384
        if encoder == 'convnext': 
            self.encoder = timm.create_model('convnext_base.fb_in22k_ft_in1k_384', in_chans=3, pretrained=False, num_classes=0)
        elif encoder == 'efficientnetv2-l': 
            self.encoder = timm.create_model('tf_efficientnetv2_l.in21k_ft_in1k', in_chans=3, pretrained=False, num_classes=0, drop_rate=0.)
        self.in_features = self.encoder.num_features
        self.flatten = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)),
                                    nn.Flatten(1),
                                    \#nn.LayerNorm(self.in_features)
                                    )
        self.l1 = nn.Linear(self.in_features, 2)
        self.l2 = nn.Linear(self.in_features, 2)
        self.l3 = nn.Linear(self.in_features, 2)
        self.l4 = nn.Linear(self.in_features, 2)
        self.l5 = nn.Linear(self.in_features, 2)
		def forward(self, x, label=None):
        \#for loc, img in x.items():
            \#print(img.shape)
        #    img = self.encoder.forward_features(img)
        #    img = self.flatten(img)
        #    x[loc] = img
        x = self.encoder.forward_features(x)
        x = self.flatten(x)
        l1 = self.l1(x)
        l2 = self.l2(x)
        l3 = self.l3(x)
        l4 = self.l4(x)
        l5 = self.l5(x)
        return {'L1/L2': l1.sigmoid(), 'L2/L3': l2.sigmoid(), 'L3/L4': l3.sigmoid(), 'L4/L5': l4.sigmoid(), 'L5/S1': l5.sigmoid()}
class ConvNextNFNDetect(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        if encoder == 'convnext': 
            self.encoder = timm.create_model('convnext_base.fb_in22k_ft_in1k_384', in_chans=3, pretrained=False, num_classes=0)
        elif encoder == 'efficientnetv2-l': 
            self.encoder = timm.create_model('tf_efficientnetv2_l.in21k_ft_in1k', in_chans=3, pretrained=False, num_classes=0, drop_rate=0.)
        self.in_features = self.encoder.num_features
        self.flatten = nn.Sequential(nn.AdaptiveAvgPool2d((1,1)),
                                    nn.Flatten(1),
                                    \#nn.LayerNorm(self.in_features)
                                    )
        self.ll1 = nn.Linear(self.in_features, 2)
        self.ll2 = nn.Linear(self.in_features, 2)
        self.ll3 = nn.Linear(self.in_features, 2)
        self.ll4 = nn.Linear(self.in_features, 2)
        self.ll5 = nn.Linear(self.in_features, 2)
        self.rl1 = nn.Linear(self.in_features, 2)
        self.rl2 = nn.Linear(self.in_features, 2)
        self.rl3 = nn.Linear(self.in_features, 2)
        self.rl4 = nn.Linear(self.in_features, 2)
        self.rl5 = nn.Linear(self.in_features, 2)
    
		def forward(self, x, label=None):
        shape = x.shape
        x = x.reshape(shape[0]*shape[1], 3, shape[-2], shape[-1])
        x = self.encoder.forward_features(x)
        x = self.flatten(x)
        x = x.reshape(shape[0], shape[1], -1)
        x_left = x[:, 0, :]
        x_right = x[:, 1, :]
        ll1 = self.ll1(x_left)
        ll2 = self.ll2(x_left)
        ll3 = self.ll3(x_left)
        ll4 = self.ll4(x_left)
        ll5 = self.ll5(x_left)
        rl1 = self.rl1(x_right)
        rl2 = self.rl2(x_right)
        rl3 = self.rl3(x_right)
        rl4 = self.rl4(x_right)
        rl5 = self.rl5(x_right)
        return {'left_L1/L2': ll1.sigmoid(),'left_L2/L3': ll2.sigmoid(),'left_L3/L4': ll3.sigmoid(), 'left_L4/L5': ll4.sigmoid(), 'left_L5/S1': ll5.sigmoid(),
                'right_L1/L2': rl1.sigmoid(), 'right_L2/L3': rl2.sigmoid(), 'right_L3/L4': rl3.sigmoid(), 'right_L4/L5': rl4.sigmoid(), 'right_L5/S1': rl5.sigmoid()}

3. Coordinate detection lightning module

class DetectModule(pl.LightningModule):
    def __init__(self, condition, encoder):
        super().__init__()
        self.config = condition
        if condition == 'scs':
            self.model = ConvNextSCSDetect(encoder)
        elif condition == 'nfn':
            self.model = ConvNextNFNDetect(encoder)
        elif  condition == 'ss': 
            pass
        \#self.ema = ExponentialMovingAverage(self.model.parameters(), decay=0.995)
        \#self.ema.to(device)
        \#self.model = torch.optim.swa_utils.AveragedModel(self.model,
        #                                                 multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
    def forward(self, batch):
        preds = self.model(batch)
        return preds

4. Coordinate inference

time
def create_label_coor(study_id, coor_df, coor, level, condition, desc): 
    _meta = meta_df.loc[meta_df.series_description==desc].copy()
    _coor = coor_df.loc[coor_df.condition == condition]
    _coor_df = {'study_id': [], 'series_id': [], 'x': [], 'y': []}
    for s, c in zip(study_id, coor): 
        sub_meta = _meta.loc[_meta.study_id == s]
        sub_coor = _coor.loc[(_coor.study_id==s) & (_coor.level==level.split('_')[-1])].squeeze(axis=0)
        \#display(sub_coor)
        meta_row = sub_meta.loc[(sub_meta.instance_number==sub_coor.instance_number) & (sub_meta.series_id==sub_coor.series_id)].squeeze(axis=0)
        x = round(meta_row.width * c[0])
        y = round(meta_row.height * c[1])
        _coor_df['study_id'].append(s)
        _coor_df['series_id'].append(sub_coor.series_id)
        _coor_df['x'].append(x)
        _coor_df['y'].append(y)
    _coor_df['level'] = level.split('_')[-1]
    _coor_df['condition'] = condition
    del _meta, _coor, sub_meta, sub_coor, meta_row
    return pd.DataFrame(_coor_df)
scs_study_id = coor_predict['scs']['study_id']
scs_coor_list = []
for k, v in coor_predict['scs'].items(): 
    if k != 'study_id': 
        scs_coor_list.append(create_label_coor(scs_study_id, pred_coor, v, k, 'Spinal Canal Stenosis', 'Sagittal T2/STIR'))
nfn_study_id = coor_predict['nfn']['study_id']
nfn_coor_list = []
for k, v in coor_predict['nfn'].items(): 
    if k != 'study_id': 
        if k.split('_')[0] == 'left': 
            condition = 'Left Neural Foraminal Narrowing'
        else: 
            condition = 'Right Neural Foraminal Narrowing'
        nfn_coor_list.append(create_label_coor(nfn_study_id, pred_coor, v, k, condition, 'Sagittal T1'))
scs_coor = pd.concat(scs_coor_list)
nfn_coor = pd.concat(nfn_coor_list)
_pred_coor = pd.concat([scs_coor, nfn_coor]).sort_values(['study_id', 'series_id', 'level'])
pred_coor_stage2 = pd.merge(pred_coor, _pred_coor, on=['study_id', 'series_id', 'level', 'condition'], how='inner')
display(pred_coor_stage2.head())
pred_coor_stage2.to_csv('stage2_coor.csv', index=False)
study_idseries_idinstance_numberconditionlevelxy
04403693928282038459Left Neural Foraminal NarrowingL1/L2387201
144036939282820384518Right Neural Foraminal NarrowingL1/L2384179
24403693928282038459Left Neural Foraminal NarrowingL2/L3348267
344036939282820384518Right Neural Foraminal NarrowingL2/L3350243
44403693928282038459Left Neural Foraminal NarrowingL3/L4311326

III. Third Stage (calc. location of axial t2)

  • calcurate depth of axial t2 for each location roughly, using xyz-coordinate (refered to @hengck’s transformation from sagittal t2 to axial t2)
  • roughly separate each locations
  • infer instance number
  • infer xy-coordinate

1. Calculate axial slice

# project 2d to 3d
def project_to_3d(row):
    sx, sy, sz = row.ipp_x, row.ipp_y, row.ipp_z
    x, y = row.x, row.y
    o0, o1, o2, o3, o4, o5 = row.iop
    delx, dely = row.ps_x, row.ps_y
    xx = o0 * delx * x + o3 * dely * y + sx
    yy = o1 * delx * x + o4 * dely * y + sy
    zz = o2 * delx * x + o5 * dely * y + sz
    return xx,yy,zz
def sag_to_ax(sub_coor, sub_meta): 
    point = sub_coor[['ipp_x', 'ipp_y', 'ipp_z']].values \#2d
    level_list = sub_coor.level.tolist()    
    # here we project 2d to 3d
    center=[] 
    for _, row in sub_coor.iterrows():
        xx,yy,zz = project_to_3d(row)
        center.append([xx,yy,zz])
    center = np.array(center) \#3d
    # == 2. we get closest axial slices to the CSC points =================
    \#df = valid_data[0].axial_t2[0].df
    orientation = np.array(sub_meta.iop.values.tolist())
    position= np.array(sub_meta[['ipp_x', 'ipp_y', 'ipp_z']].values.tolist())
    ox = orientation[:, :3]
    oy = orientation[:, 3:]
    oz = np.cross(ox,oy)
    t = center.reshape(-1,1,3) - position.reshape(1,-1,3)
    dis = (oz.reshape(1,-1,3) * t).sum(-1)  # np.dot(point-s,oz)
    dis = np.fabs(dis)
    closest = dis.argmin(-1)
    closest_df = sub_meta.iloc[closest]
    closest_df['level'] = level_list#['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']
    closest_df['x'] = 0
    closest_df['y'] = 0
    \#closest_df = pd.concat([closest_df, closest_df])
    \#closest_df['condition'] = ['Left Subarticular Stenosis']*5 + ['Right Subarticular Stenosis']*5
    return closest_df[['study_id', 'series_id', 'instance_number', 'level']]
# sagittal t2 => axial t2
scs_coor = pred_coor_stage2.loc[pred_coor_stage2.condition=='Spinal Canal Stenosis'].copy()
scs_coor = scs_coor.merge(meta_df, on=['study_id', 'series_id', 'instance_number'], how='left')
study_id = scs_coor.study_id.unique()
ax_meta  = meta_df.loc[(meta_df.series_description=='Axial T2')]
closest_ax_list = []
for s in tqdm(study_id, total=len(study_id)): 
    sub_coor = scs_coor.loc[scs_coor.study_id==s]
    sub_meta = ax_meta.loc[ax_meta.study_id==s]
    closest_ax_list.append(sag_to_ax(sub_coor, sub_meta)) 
closest_ax = pd.concat(closest_ax_list)
closest_ax.head()
study_idseries_idinstance_numberlevel
1644036939348197151817
2244036939348197151823
2744036939348197151828
3344036939348197151834
3844036939348197151839

2. Subarticular stenosis coordinate prediction dataset

class SSDetectDataset(Dataset):
    def __init__(self, ax, usage='train'):
        self.ax = ax
        self.id = ax.study_id.unique()
        self.usage = usage
        self.id = list(set(self.id) - set([3637444890]))
        \#self.id = [2773343225]
        \#self.id = [1782095928]
        self.resize = v2.Resize((384, 384))
        
    def __getitem__(self, index):
        study_id = self.id[index]
        volume = self.for_ss(study_id)
        return volume, torch.tensor(study_id)
   def for_ss(self, study_id):
        ax = self.ax.loc[self.ax.study_id==study_id]
        img_dict = {}
        for _, row in ax.iterrows():
            series_id, instance_number = row.series_id, row.instance_number
            img = self.load_dicom(IMAGE_PATH + f'{study_id}/{series_id}/{instance_number}.dcm').astype(np.float32)
            img = self.resize(torch.tensor(img)[None, ...])
            img = self.normalize(img)
            img_dict[row.level] = img
        img_list = []
        for k in ['L1/L2', 'L2/L3', 'L3/L4', 'L4/L5', 'L5/S1']: 
            img_list.append(img_dict[k])
        volume = torch.stack(img_list).contiguous()
        return volume
    def normalize(self, x):
        upper = torch.quantile(x, torch.tensor([0.99]))
        lower = torch.quantile(x, torch.tensor([0.01]))
        x = torch.clip(x, lower, upper)
        x = x - torch.min(x)
        x = x / (torch.max(x)+1e-6)
        return x