实现mode 2
This commit is contained in:
49
main.py
49
main.py
@@ -5,7 +5,7 @@ import pandas as pd
|
||||
import streamlit as st
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import and_, case
|
||||
|
||||
from bak.init_data import BatchDataCreate, BatchDataRead
|
||||
from db_utils import BatchData, session
|
||||
@@ -32,25 +32,50 @@ class DataRow(BaseModel):
|
||||
update_cache: bool = False
|
||||
|
||||
|
||||
def make_train_dicts(conditions: tuple):
|
||||
return [
|
||||
DataRow(path=_.ann_file_lbs,
|
||||
img_prefix=_.img_prefix).dict()
|
||||
for _ in session.query(BatchData).where(and_(*conditions)).all()
|
||||
]
|
||||
|
||||
|
||||
def make_train_data(mode_id) -> dict[str, list]:
|
||||
res = {
|
||||
"train_data": [],
|
||||
"val_data": [],
|
||||
}
|
||||
if mode_id == 1:
|
||||
for db_obj in session.query(BatchData).where(and_(
|
||||
BatchData.is_train == 1,
|
||||
BatchData.ann_file_lbs.is_not(None)
|
||||
)).all():
|
||||
res['train_data'].append(DataRow(path=db_obj.ann_file_lbs,
|
||||
img_prefix=db_obj.img_prefix).dict())
|
||||
res['train_data'] = make_train_dicts(conditions=(
|
||||
BatchData.is_train == 1,
|
||||
BatchData.ann_file_lbs.is_not(None)
|
||||
))
|
||||
val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all()
|
||||
for db_obj in val_db_objs:
|
||||
res['val_data'].append(DataRow(path=db_obj.ann_file_lbs,
|
||||
img_prefix=db_obj.img_prefix).dict())
|
||||
|
||||
elif mode_id == 2:
|
||||
results = (session.query(BatchData)
|
||||
.with_entities(BatchData.img_prefix,
|
||||
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
|
||||
else_=BatchData.ann_file).label('ann_file'))
|
||||
.where(BatchData.is_train == 1).all())
|
||||
|
||||
res['train_data'] = [
|
||||
DataRow(path=_.ann_file,
|
||||
img_prefix=_.img_prefix).dict()
|
||||
for _ in results
|
||||
]
|
||||
val_db_objs = (session.query(BatchData)
|
||||
.with_entities(BatchData.img_prefix,
|
||||
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
|
||||
else_=BatchData.ann_file).label('ann_file'))
|
||||
.where(BatchData.is_validation == 1).all())
|
||||
for db_obj in val_db_objs:
|
||||
res['val_data'].append(DataRow(path=db_obj.ann_file,
|
||||
img_prefix=db_obj.img_prefix).dict())
|
||||
|
||||
val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all()
|
||||
for db_obj in val_db_objs:
|
||||
res['val_data'].append(DataRow(path=db_obj.ann_file_lbs,
|
||||
img_prefix=db_obj.img_prefix).dict())
|
||||
return res
|
||||
|
||||
|
||||
@@ -71,7 +96,7 @@ st.session_state.setdefault('evolve_r', 0.02)
|
||||
st.session_state.setdefault('n_trail', 10)
|
||||
st.session_state.setdefault('n_epoch', 1)
|
||||
st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0])
|
||||
st.session_state.setdefault('mode', 1)
|
||||
st.session_state.setdefault('mode', 2)
|
||||
st.session_state.setdefault('configs', {
|
||||
k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user