实现mode 2
This commit is contained in:
43
main.py
43
main.py
@@ -5,7 +5,7 @@ import pandas as pd
|
|||||||
import streamlit as st
|
import streamlit as st
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import and_
|
from sqlalchemy import and_, case
|
||||||
|
|
||||||
from bak.init_data import BatchDataCreate, BatchDataRead
|
from bak.init_data import BatchDataCreate, BatchDataRead
|
||||||
from db_utils import BatchData, session
|
from db_utils import BatchData, session
|
||||||
@@ -32,25 +32,50 @@ class DataRow(BaseModel):
|
|||||||
update_cache: bool = False
|
update_cache: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def make_train_dicts(conditions: tuple):
|
||||||
|
return [
|
||||||
|
DataRow(path=_.ann_file_lbs,
|
||||||
|
img_prefix=_.img_prefix).dict()
|
||||||
|
for _ in session.query(BatchData).where(and_(*conditions)).all()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def make_train_data(mode_id) -> dict[str, list]:
|
def make_train_data(mode_id) -> dict[str, list]:
|
||||||
res = {
|
res = {
|
||||||
"train_data": [],
|
"train_data": [],
|
||||||
"val_data": [],
|
"val_data": [],
|
||||||
}
|
}
|
||||||
if mode_id == 1:
|
if mode_id == 1:
|
||||||
for db_obj in session.query(BatchData).where(and_(
|
res['train_data'] = make_train_dicts(conditions=(
|
||||||
BatchData.is_train == 1,
|
BatchData.is_train == 1,
|
||||||
BatchData.ann_file_lbs.is_not(None)
|
BatchData.ann_file_lbs.is_not(None)
|
||||||
)).all():
|
))
|
||||||
res['train_data'].append(DataRow(path=db_obj.ann_file_lbs,
|
|
||||||
img_prefix=db_obj.img_prefix).dict())
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all()
|
val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all()
|
||||||
for db_obj in val_db_objs:
|
for db_obj in val_db_objs:
|
||||||
res['val_data'].append(DataRow(path=db_obj.ann_file_lbs,
|
res['val_data'].append(DataRow(path=db_obj.ann_file_lbs,
|
||||||
img_prefix=db_obj.img_prefix).dict())
|
img_prefix=db_obj.img_prefix).dict())
|
||||||
|
|
||||||
|
elif mode_id == 2:
|
||||||
|
results = (session.query(BatchData)
|
||||||
|
.with_entities(BatchData.img_prefix,
|
||||||
|
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
|
||||||
|
else_=BatchData.ann_file).label('ann_file'))
|
||||||
|
.where(BatchData.is_train == 1).all())
|
||||||
|
|
||||||
|
res['train_data'] = [
|
||||||
|
DataRow(path=_.ann_file,
|
||||||
|
img_prefix=_.img_prefix).dict()
|
||||||
|
for _ in results
|
||||||
|
]
|
||||||
|
val_db_objs = (session.query(BatchData)
|
||||||
|
.with_entities(BatchData.img_prefix,
|
||||||
|
case((BatchData.ann_file_lbs.is_not(None), BatchData.ann_file_lbs),
|
||||||
|
else_=BatchData.ann_file).label('ann_file'))
|
||||||
|
.where(BatchData.is_validation == 1).all())
|
||||||
|
for db_obj in val_db_objs:
|
||||||
|
res['val_data'].append(DataRow(path=db_obj.ann_file,
|
||||||
|
img_prefix=db_obj.img_prefix).dict())
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
@@ -71,7 +96,7 @@ st.session_state.setdefault('evolve_r', 0.02)
|
|||||||
st.session_state.setdefault('n_trail', 10)
|
st.session_state.setdefault('n_trail', 10)
|
||||||
st.session_state.setdefault('n_epoch', 1)
|
st.session_state.setdefault('n_epoch', 1)
|
||||||
st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0])
|
st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0])
|
||||||
st.session_state.setdefault('mode', 1)
|
st.session_state.setdefault('mode', 2)
|
||||||
st.session_state.setdefault('configs', {
|
st.session_state.setdefault('configs', {
|
||||||
k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')
|
k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')
|
||||||
})
|
})
|
||||||
|
|||||||
Reference in New Issue
Block a user