实现mode 2

This commit is contained in:
leo
2024-03-07 22:52:08 +08:00
parent bf2833e590
commit dae4dd0d9f

49
main.py
View File

@@ -5,7 +5,7 @@ import pandas as pd
import streamlit as st
from loguru import logger
from pydantic import BaseModel
from sqlalchemy import and_
from sqlalchemy import and_, case
from bak.init_data import BatchDataCreate, BatchDataRead
from db_utils import BatchData, session
@@ -32,25 +32,50 @@ class DataRow(BaseModel):
update_cache: bool = False
def make_train_dicts(conditions: tuple):
return [
DataRow(path=_.ann_file_lbs,
img_prefix=_.img_prefix).dict()
for _ in session.query(BatchData).where(and_(*conditions)).all()
]
def make_train_data(mode_id) -> dict[str, list]:
res = {
"train_data": [],
"val_data": [],
}
if mode_id == 1:
for db_obj in session.query(BatchData).where(and_(
BatchData.is_train == 1,
BatchData.ann_file_lbs.is_not(None)
)).all():
res['train_data'].append(DataRow(path=db_obj.ann_file_lbs,
img_prefix=db_obj.img_prefix).dict())
res['train_data'] = make_train_dicts(conditions=(
BatchData.is_train == 1,
BatchData.ann_file_lbs.is_not(None)
))
val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all()
for db_obj in val_db_objs:
res['val_data'].append(DataRow(path=db_obj.ann_file_lbs,
img_prefix=db_obj.img_prefix).dict())
elif mode_id == 2:
results = (session.query(BatchData)
.with_entities(BatchData.img_prefix,
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
else_=BatchData.ann_file).label('ann_file'))
.where(BatchData.is_train == 1).all())
res['train_data'] = [
DataRow(path=_.ann_file,
img_prefix=_.img_prefix).dict()
for _ in results
]
val_db_objs = (session.query(BatchData)
.with_entities(BatchData.img_prefix,
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
else_=BatchData.ann_file).label('ann_file'))
.where(BatchData.is_validation == 1).all())
for db_obj in val_db_objs:
res['val_data'].append(DataRow(path=db_obj.ann_file,
img_prefix=db_obj.img_prefix).dict())
val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all()
for db_obj in val_db_objs:
res['val_data'].append(DataRow(path=db_obj.ann_file_lbs,
img_prefix=db_obj.img_prefix).dict())
return res
@@ -71,7 +96,7 @@ st.session_state.setdefault('evolve_r', 0.02)
st.session_state.setdefault('n_trail', 10)
st.session_state.setdefault('n_epoch', 1)
st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0])
st.session_state.setdefault('mode', 1)
st.session_state.setdefault('mode', 2)
st.session_state.setdefault('configs', {
k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')
})