diff --git a/main.py b/main.py index b08c4a4..4a2e9b4 100644 --- a/main.py +++ b/main.py @@ -5,7 +5,7 @@ import pandas as pd import streamlit as st from loguru import logger from pydantic import BaseModel -from sqlalchemy import and_ +from sqlalchemy import and_, case from bak.init_data import BatchDataCreate, BatchDataRead from db_utils import BatchData, session @@ -32,25 +32,50 @@ class DataRow(BaseModel): update_cache: bool = False +def make_train_dicts(conditions: tuple): + return [ + DataRow(path=_.ann_file_lbs, + img_prefix=_.img_prefix).dict() + for _ in session.query(BatchData).where(and_(*conditions)).all() + ] + + def make_train_data(mode_id) -> dict[str, list]: res = { "train_data": [], "val_data": [], } if mode_id == 1: - for db_obj in session.query(BatchData).where(and_( - BatchData.is_train == 1, - BatchData.ann_file_lbs.is_not(None) - )).all(): - res['train_data'].append(DataRow(path=db_obj.ann_file_lbs, - img_prefix=db_obj.img_prefix).dict()) + res['train_data'] = make_train_dicts(conditions=( + BatchData.is_train == 1, + BatchData.ann_file_lbs.is_not(None) + )) + val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all() + for db_obj in val_db_objs: + res['val_data'].append(DataRow(path=db_obj.ann_file_lbs, + img_prefix=db_obj.img_prefix).dict()) + elif mode_id == 2: + results = (session.query(BatchData) + .with_entities(BatchData.img_prefix, + case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs), + else_=BatchData.ann_file).label('ann_file')) + .where(BatchData.is_train == 1).all()) + res['train_data'] = [ + DataRow(path=_.ann_file, + img_prefix=_.img_prefix).dict() + for _ in results + ] + val_db_objs = (session.query(BatchData) + .with_entities(BatchData.img_prefix, + case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs), + else_=BatchData.ann_file).label('ann_file')) + .where(BatchData.is_validation == 1).all()) + for db_obj in val_db_objs: + res['val_data'].append(DataRow(path=db_obj.ann_file, + img_prefix=db_obj.img_prefix).dict()) - val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all() - for db_obj in val_db_objs: - res['val_data'].append(DataRow(path=db_obj.ann_file_lbs, - img_prefix=db_obj.img_prefix).dict()) return res @@ -71,7 +96,7 @@ st.session_state.setdefault('evolve_r', 0.02) st.session_state.setdefault('n_trail', 10) st.session_state.setdefault('n_epoch', 1) st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0]) -st.session_state.setdefault('mode', 1) +st.session_state.setdefault('mode', 2) st.session_state.setdefault('configs', { k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode') })