实现mode 1
This commit is contained in:
@@ -27,3 +27,4 @@ lbs_path 不生成
|
|||||||
lbs模式: select(train==1, path=lbs_path)
|
lbs模式: select(train==1, path=lbs_path)
|
||||||
lbs优先模式: select(train==1, path=lbs_path if lbs_path else path)
|
lbs优先模式: select(train==1, path=lbs_path if lbs_path else path)
|
||||||
不使用lbs模式: select(train==1, path=path)
|
不使用lbs模式: select(train==1, path=path)
|
||||||
|
|
||||||
|
|||||||
38
main.py
38
main.py
@@ -4,6 +4,8 @@ import os
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import and_
|
||||||
|
|
||||||
from bak.init_data import BatchDataCreate, BatchDataRead
|
from bak.init_data import BatchDataCreate, BatchDataRead
|
||||||
from db_utils import BatchData, session
|
from db_utils import BatchData, session
|
||||||
@@ -23,6 +25,35 @@ st.set_page_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DataRow(BaseModel):
|
||||||
|
path: str | None = None
|
||||||
|
img_prefix: str | None = None
|
||||||
|
filter_empty_gt: bool = False
|
||||||
|
update_cache: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def make_train_data(mode_id) -> dict[str, list]:
|
||||||
|
res = {
|
||||||
|
"train_data": [],
|
||||||
|
"val_data": [],
|
||||||
|
}
|
||||||
|
if mode_id == 1:
|
||||||
|
for db_obj in session.query(BatchData).where(and_(
|
||||||
|
BatchData.is_train == 1,
|
||||||
|
BatchData.ann_file_lbs.is_not(None)
|
||||||
|
)).all():
|
||||||
|
res['train_data'].append(DataRow(path=db_obj.ann_file_lbs,
|
||||||
|
img_prefix=db_obj.img_prefix).dict())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
val_db_objs = session.query(BatchData).where(BatchData.is_validation == 1).all()
|
||||||
|
for db_obj in val_db_objs:
|
||||||
|
res['val_data'].append(DataRow(path=db_obj.ann_file_lbs,
|
||||||
|
img_prefix=db_obj.img_prefix).dict())
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
@st.cache_data
|
@st.cache_data
|
||||||
def list_ckpt_paths(dir_path):
|
def list_ckpt_paths(dir_path):
|
||||||
return os.listdir(dir_path)
|
return os.listdir(dir_path)
|
||||||
@@ -36,11 +67,11 @@ def get_data_from_db():
|
|||||||
|
|
||||||
st.session_state.setdefault('data_table', [])
|
st.session_state.setdefault('data_table', [])
|
||||||
st.session_state.setdefault('username', '')
|
st.session_state.setdefault('username', '')
|
||||||
st.session_state.setdefault('evolve_r', 0.2)
|
st.session_state.setdefault('evolve_r', 0.02)
|
||||||
st.session_state.setdefault('n_trail', 10)
|
st.session_state.setdefault('n_trail', 10)
|
||||||
st.session_state.setdefault('n_epoch', 2)
|
st.session_state.setdefault('n_epoch', 1)
|
||||||
st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0])
|
st.session_state.setdefault('ckpt_path', list_ckpt_paths(BASE_CKPT_DIR)[0])
|
||||||
st.session_state.setdefault('mode', 2)
|
st.session_state.setdefault('mode', 1)
|
||||||
st.session_state.setdefault('configs', {
|
st.session_state.setdefault('configs', {
|
||||||
k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')
|
k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')
|
||||||
})
|
})
|
||||||
@@ -52,6 +83,7 @@ left_col, right_col = st.columns([3, 1])
|
|||||||
|
|
||||||
def train():
|
def train():
|
||||||
right_col.json({k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')})
|
right_col.json({k: st.session_state[k] for k in ('evolve_r', 'n_trail', 'n_epoch', 'ckpt_path', 'mode')})
|
||||||
|
right_col.json(make_train_data(st.session_state.mode))
|
||||||
|
|
||||||
|
|
||||||
def table_update_handler():
|
def table_update_handler():
|
||||||
|
|||||||
Reference in New Issue
Block a user