优化mode 1,2

This commit is contained in:
leo
2024-03-07 23:04:27 +08:00
parent dae4dd0d9f
commit da51178211

58
main.py
View File

@@ -32,50 +32,52 @@ class DataRow(BaseModel):
update_cache: bool = False update_cache: bool = False
def make_train_dicts(conditions: tuple): def make_train_dicts(with_entities: tuple, conditions: tuple):
return [ return [
DataRow(path=_.ann_file_lbs, DataRow(path=_.ann_file,
img_prefix=_.img_prefix).dict() img_prefix=_.img_prefix).dict()
for _ in session.query(BatchData).where(and_(*conditions)).all() for _ in session.query(BatchData)
.with_entities(*with_entities)
.where(and_(*conditions)).all()
] ]
def make_train_data(mode_id) -> dict[str, list]: def prepare_train_data(mode_id) -> dict[str, list]:
res = { res = {
"train_data": [], "train_data": [],
"val_data": [], "val_data": [],
} }
if mode_id == 1: if mode_id == 1:
res['train_data'] = make_train_dicts(
with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')),
conditions=(BatchData.is_train == 1, BatchData.ann_file_lbs.is_not(None))
)
res['val_data'] = make_train_dicts(
with_entities=(BatchData.img_prefix, BatchData.ann_file_lbs.label('ann_file')),
conditions=(BatchData.is_validation == 1, BatchData.ann_file_lbs.is_not(None))
)
elif mode_id == 2:
res['train_data'] = make_train_dicts(
with_entities=(BatchData.img_prefix,
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
else_=BatchData.ann_file).label('ann_file')),
conditions=(BatchData.is_train == 1,)
)
res['val_data'] = make_train_dicts(
with_entities=(BatchData.img_prefix,
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
else_=BatchData.ann_file).label('ann_file')),
conditions=(BatchData.is_validation == 1,)
)
elif mode_id == 3:
res['train_data'] = make_train_dicts(conditions=( res['train_data'] = make_train_dicts(conditions=(
BatchData.is_train == 1, BatchData.is_train == 1,
BatchData.ann_file_lbs.is_not(None) BatchData.ann_file.is_not(None)
)) ))
val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all() val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all()
for db_obj in val_db_objs:
res['val_data'].append(DataRow(path=db_obj.ann_file_lbs,
img_prefix=db_obj.img_prefix).dict())
elif mode_id == 2:
results = (session.query(BatchData)
.with_entities(BatchData.img_prefix,
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
else_=BatchData.ann_file).label('ann_file'))
.where(BatchData.is_train == 1).all())
res['train_data'] = [
DataRow(path=_.ann_file,
img_prefix=_.img_prefix).dict()
for _ in results
]
val_db_objs = (session.query(BatchData)
.with_entities(BatchData.img_prefix,
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
else_=BatchData.ann_file).label('ann_file'))
.where(BatchData.is_validation == 1).all())
for db_obj in val_db_objs: for db_obj in val_db_objs:
res['val_data'].append(DataRow(path=db_obj.ann_file, res['val_data'].append(DataRow(path=db_obj.ann_file,
img_prefix=db_obj.img_prefix).dict()) img_prefix=db_obj.img_prefix).dict())
return res return res
@@ -108,7 +110,7 @@ left_col, right_col = st.columns([3, 1])
def train(): def train():
right_col.json({k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')}) right_col.json({k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')})
right_col.json(make_train_data(st.session_state.mode)) right_col.json(prepare_train_data(st.session_state.mode))
def table_update_handler(): def table_update_handler():