实现mode 2

This commit is contained in:
leo
2024-03-07 22:52:08 +08:00
parent bf2833e590
commit dae4dd0d9f

43
main.py
View File

@@ -5,7 +5,7 @@ import pandas as pd
import streamlit as st import streamlit as st
from loguru import logger from loguru import logger
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import and_ from sqlalchemy import and_, case
from bak.init_data import BatchDataCreate, BatchDataRead from bak.init_data import BatchDataCreate, BatchDataRead
from db_utils import BatchData, session from db_utils import BatchData, session
@@ -32,25 +32,50 @@ class DataRow(BaseModel):
update_cache: bool = False update_cache: bool = False
def make_train_dicts(conditions: tuple):
return [
DataRow(path=_.ann_file_lbs,
img_prefix=_.img_prefix).dict()
for _ in session.query(BatchData).where(and_(*conditions)).all()
]
def make_train_data(mode_id) -> dict[str, list]: def make_train_data(mode_id) -> dict[str, list]:
res = { res = {
"train_data": [], "train_data": [],
"val_data": [], "val_data": [],
} }
if mode_id == 1: if mode_id == 1:
for db_obj in session.query(BatchData).where(and_( res['train_data'] = make_train_dicts(conditions=(
BatchData.is_train == 1, BatchData.is_train == 1,
BatchData.ann_file_lbs.is_not(None) BatchData.ann_file_lbs.is_not(None)
)).all(): ))
res['train_data'].append(DataRow(path=db_obj.ann_file_lbs,
img_prefix=db_obj.img_prefix).dict())
val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all() val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all()
for db_obj in val_db_objs: for db_obj in val_db_objs:
res['val_data'].append(DataRow(path=db_obj.ann_file_lbs, res['val_data'].append(DataRow(path=db_obj.ann_file_lbs,
img_prefix=db_obj.img_prefix).dict()) img_prefix=db_obj.img_prefix).dict())
elif mode_id == 2:
results = (session.query(BatchData)
.with_entities(BatchData.img_prefix,
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
else_=BatchData.ann_file).label('ann_file'))
.where(BatchData.is_train == 1).all())
res['train_data'] = [
DataRow(path=_.ann_file,
img_prefix=_.img_prefix).dict()
for _ in results
]
val_db_objs = (session.query(BatchData)
.with_entities(BatchData.img_prefix,
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
else_=BatchData.ann_file).label('ann_file'))
.where(BatchData.is_validation == 1).all())
for db_obj in val_db_objs:
res['val_data'].append(DataRow(path=db_obj.ann_file,
img_prefix=db_obj.img_prefix).dict())
return res return res
@@ -71,7 +96,7 @@ st.session_state.setdefault('evolve_r', 0.02)
st.session_state.setdefault('n_trail', 10) st.session_state.setdefault('n_trail', 10)
st.session_state.setdefault('n_epoch', 1) st.session_state.setdefault('n_epoch', 1)
st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0]) st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0])
st.session_state.setdefault('mode', 1) st.session_state.setdefault('mode', 2)
st.session_state.setdefault('configs', { st.session_state.setdefault('configs', {
k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode') k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')
}) })