YeCanming commited on
Commit
4dcd117
·
1 Parent(s): 5bf137e

feat: presentation

Browse files
.streamlit/config.toml CHANGED
@@ -1,17 +1,17 @@
1
  [theme]
2
  base = "dark"
3
  baseFontSize = 15
4
- primaryColor = "#6EA8FE"
5
- backgroundColor = "#0D1117"
6
- secondaryBackgroundColor = "#1A1F2B"
7
- textColor = "#D1D5DB"
8
- linkColor = "#B8C0FF"
9
- borderColor = "#2E3440"
10
  showWidgetBorder = false
11
  baseRadius = "0.3rem"
12
- font = "JetBrains Mono"
13
 
14
  [theme.sidebar]
15
- backgroundColor = "#0A0A0A"
16
- secondaryBackgroundColor = "#1A1A1A"
17
- borderColor = "#2E3440"
 
1
  [theme]
2
  base = "dark"
3
  baseFontSize = 15
4
+ primaryColor = "#1ED760"
5
+ backgroundColor = "#121212"
6
+ secondaryBackgroundColor = "#333333"
7
+ textColor = "#FFFFFF"
8
+ linkColor = "#9D9D9D"
9
+ borderColor = "#7F7F7F"
10
  showWidgetBorder = false
11
  baseRadius = "0.3rem"
12
+ font = "Poppins"
13
 
14
  [theme.sidebar]
15
+ backgroundColor = "#000000"
16
+ secondaryBackgroundColor = "#333333"
17
+ borderColor = "#696969"
.streamlit/theme.toml CHANGED
@@ -1,2 +1,2 @@
1
- theme_name = "柳暗 (Willows Dark) 🌒"
2
- theme_poem = "🌒「深而不死黑,蓝而不夺目,静而不沉闷」柳影婆娑之下,代码悄然生长。"
 
1
+ theme_name = "Spotify"
2
+ theme_poem = ""
src copy/components/metrics_visualizer.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import plotly.graph_objects as go
3
+ import plotly.express as px
4
+ import time
5
+
6
+
7
+ def render_metrics_charts(current_study, current_trial):
8
+ """把原来在 tab_charts 里的 DataFrame 可视化代码都搬进来,只依赖 current_trial 和 st.session_state."""
9
+ if not current_trial.metrics_data:
10
+ st.info("当前 Trial 没有可显示的指标数据。")
11
+ return
12
+
13
+ # 全局步骤控制、st.session_state.shared_selected_global_step 等逻辑照搬
14
+ # …(省略,直接粘进去原来 streamlit_app.py 中的控制器和自动播放部分)…
15
+
16
+ # 然后就是那段循环绘图和 st.metric + st.plotly_chart + st.dataframe
17
+ metric_names = sorted(current_trial.metrics_data.keys())
18
+ cols_per_row = st.slider(
19
+ "每行图表数量",
20
+ 1,
21
+ 4,
22
+ 2,
23
+ key=f"cols_slider_{current_study.name}_{current_trial.name}",
24
+ )
25
+ for i in range(0, len(metric_names), cols_per_row):
26
+ chunk = metric_names[i : i + cols_per_row]
27
+ cols = st.columns(len(chunk))
28
+ for j, m in enumerate(chunk):
29
+ with cols[j]:
30
+ df = current_trial.get_metric_dataframe(m)
31
+ if df is None or df.empty:
32
+ st.warning(f"指标 '{m}' 无数据")
33
+ continue
34
+ st.subheader(m)
35
+ # …Metric 计算 + Plotly 绘制 + 高亮 + 点击同步…
36
+ fig = go.Figure()
37
+ # …省略:完全同原来逻辑…
38
+ st.plotly_chart(fig, use_container_width=True, key=f"chart_{m}")
39
+ st.dataframe(df)
src copy/data_loader.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data_loader.py
2
+ from pathlib import Path
3
+ from typing import Dict, List, Any
4
+ import pandas as pd
5
+ import tomli
6
+ import streamlit as st
7
+ from functools import lru_cache # For non-Streamlit specific caching if needed
8
+
9
+ # Assuming utils.py is in the same directory
10
+ from utils import DATA_ROOT_PATH # Used for ensuring directory exists
11
+
12
+ # --- Cache Clearing Functions ---
13
+ # These are more specific cache clearing functions that can be called by model methods.
14
+
15
+
16
+ def clear_study_cache():
17
+ """Clears all study discovery cache."""
18
+ discover_studies_cached.clear()
19
+ st.toast("所有 Study 发现缓存已清除。")
20
+
21
+
22
+ def clear_trial_cache():
23
+ """Clears all trial-related data loading caches."""
24
+ # This is a bit broad. Ideally, clear caches for specific trials/studies.
25
+ load_input_variables_from_path.clear()
26
+ load_all_metrics_for_trial_path.clear()
27
+ discover_trials_from_path.clear()
28
+ st.toast("所有 Trial 数据加载缓存已清除。")
29
+
30
+
31
+ def clear_specific_trial_metric_cache(trial_path: Path):
32
+ load_all_metrics_for_trial_path.clear() # This clears the whole cache for this func
33
+ # For more granular control with @st.cache_data, you'd typically rely on Streamlit's
34
+ # automatic cache invalidation based on input args, or rerun.
35
+ # If using lru_cache, you could do: load_all_metrics_for_trial_path.cache_clear()
36
+ # but for st.cache_data, clearing for specific args is not direct.
37
+ # The common pattern is to clear the entire function's cache.
38
+ st.toast(f"Trial '{trial_path.name}' 的指标缓存已清除 (函数级别)。")
39
+
40
+
41
+ def clear_specific_trial_input_vars_cache(trial_path: Path):
42
+ load_input_variables_from_path.clear()
43
+ st.toast(f"Trial '{trial_path.name}' 的参数缓存已清除 (函数级别)。")
44
+
45
+
46
+ def clear_specific_study_trial_discovery_cache(study_path: Path):
47
+ discover_trials_from_path.clear()
48
+ st.toast(f"Study '{study_path.name}' 的 Trial 发现缓存已清除 (函数级别)。")
49
+
50
+
51
+ # --- Data Discovery and Loading Functions (Cached) ---
52
+
53
+
54
+ def ensure_data_directory_exists(data_path: Path = DATA_ROOT_PATH):
55
+ """Ensures the root data directory exists."""
56
+ if not data_path.exists():
57
+ try:
58
+ data_path.mkdir(parents=True, exist_ok=True)
59
+ st.info(f"数据目录 {data_path} 已创建。")
60
+ except Exception as e:
61
+ st.error(f"创建数据目录 {data_path} 失败: {e}")
62
+ st.stop()
63
+ elif not data_path.is_dir():
64
+ st.error(f"路径 {data_path} 已存在但不是一个目录。")
65
+ st.stop()
66
+
67
+
68
+ @st.cache_data(ttl=3600) # Cache for 1 hour, or adjust as needed
69
+ def discover_studies_cached(
70
+ _data_root: Path,
71
+ ) -> Dict[
72
+ str, Any
73
+ ]: # Return type hint as Any to avoid circular dep with data_models.Study
74
+ """
75
+ Scans the data_root for study directories and returns a dictionary
76
+ mapping study names to Study objects (or just their paths initially).
77
+ The actual Study object creation happens in the main app for now.
78
+ """
79
+ # To avoid issues with caching complex objects directly, or circular dependencies,
80
+ # this function can return simpler structures like Dict[str, Path]
81
+ # and the main app or model can instantiate Study objects.
82
+ # For this iteration, we'll import Study here for convenience, assuming careful structure.
83
+ from data_models import (
84
+ Study,
85
+ ) # Local import to help with potential circularity if models grow complex
86
+
87
+ if not _data_root.is_dir():
88
+ return {}
89
+ studies = {}
90
+ for d in _data_root.iterdir():
91
+ if d.is_dir():
92
+ studies[d.name] = Study(name=d.name, path=d)
93
+ return studies
94
+
95
+
96
+ @st.cache_data(ttl=3600)
97
+ def discover_trials_from_path(_study_path: Path) -> Dict[str, Path]:
98
+ """Scans a study_path for trial directories."""
99
+ if not _study_path.is_dir():
100
+ return {}
101
+ trials = {}
102
+ for d in _study_path.iterdir():
103
+ if d.is_dir():
104
+ trials[d.name] = d
105
+ return trials
106
+
107
+
108
+ @st.cache_data(ttl=3600)
109
+ def load_input_variables_from_path(_trial_path: Path) -> Dict[str, Any]:
110
+ """Loads input_variables.toml from a trial directory."""
111
+ input_vars_file = _trial_path / "input_variables.toml"
112
+ if input_vars_file.exists():
113
+ try:
114
+ with open(input_vars_file, "rb") as f:
115
+ return tomli.load(f)
116
+ except tomli.TOMLDecodeError:
117
+ # st.error(f"错误:无法解析 input_variables.toml 文件于 {_trial_path}") # Avoid st.error in cached funcs if possible
118
+ print(f"Error parsing input_variables.toml at {_trial_path}")
119
+ return {}
120
+ return {}
121
+
122
+
123
+ def _load_single_metric_toml(_toml_file_path: Path) -> pd.DataFrame:
124
+ """Loads metrics from a single TOML file into a DataFrame."""
125
+ if not _toml_file_path.exists():
126
+ return pd.DataFrame()
127
+ try:
128
+ with open(_toml_file_path, "rb") as f:
129
+ data = tomli.load(f)
130
+ metrics_list = data.get("metrics", [])
131
+ if not metrics_list:
132
+ return pd.DataFrame()
133
+ return pd.DataFrame(metrics_list)
134
+ except tomli.TOMLDecodeError:
135
+ print(f"Error parsing TOML file: {_toml_file_path.name}")
136
+ return pd.DataFrame()
137
+ except Exception as e:
138
+ print(f"Error loading {_toml_file_path.name}: {e}")
139
+ return pd.DataFrame()
140
+
141
+
142
+ @st.cache_data(ttl=300) # Cache metric data for 5 minutes
143
+ def load_all_metrics_for_trial_path(_trial_path: Path) -> Dict[str, pd.DataFrame]:
144
+ """
145
+ Loads all metrics from all tracks in a trial.
146
+ Returns a dictionary where keys are metric names (e.g., 'loss', 'accuracy')
147
+ and values are DataFrames containing 'global_step', 'value', and 'track'.
148
+ """
149
+ scalar_dir = _trial_path / "logs" / "scalar"
150
+ if not scalar_dir.is_dir():
151
+ return {}
152
+
153
+ all_metrics_data_combined: Dict[str, pd.DataFrame] = {}
154
+
155
+ for toml_file in scalar_dir.glob("metrics_*.toml"):
156
+ track_name = toml_file.stem.replace("metrics_", "")
157
+ df_track = _load_single_metric_toml(toml_file)
158
+
159
+ if df_track.empty or "global_step" not in df_track.columns:
160
+ continue
161
+
162
+ id_vars = ["global_step"]
163
+ value_vars = [col for col in df_track.columns if col not in id_vars]
164
+
165
+ if not value_vars:
166
+ continue
167
+
168
+ # Process each metric column individually to build up the combined DataFrame
169
+ for metric_col_name in value_vars:
170
+ try:
171
+ # Create a DataFrame for the current metric and track
172
+ current_metric_df = df_track[["global_step", metric_col_name]].copy()
173
+ current_metric_df.rename(
174
+ columns={metric_col_name: "value"}, inplace=True
175
+ )
176
+ current_metric_df["track"] = track_name
177
+ current_metric_df["value"] = pd.to_numeric(
178
+ current_metric_df["value"], errors="coerce"
179
+ )
180
+ current_metric_df.dropna(subset=["value"], inplace=True)
181
+
182
+ if current_metric_df.empty:
183
+ continue
184
+
185
+ # Append to the combined DataFrame for this metric_col_name
186
+ if metric_col_name not in all_metrics_data_combined:
187
+ all_metrics_data_combined[metric_col_name] = current_metric_df
188
+ else:
189
+ all_metrics_data_combined[metric_col_name] = pd.concat(
190
+ [all_metrics_data_combined[metric_col_name], current_metric_df],
191
+ ignore_index=True,
192
+ )
193
+ except Exception as e:
194
+ print(
195
+ f"Error processing metric '{metric_col_name}' from file '{toml_file.name}': {e}"
196
+ )
197
+ continue
198
+
199
+ # Sort data by global_step for proper line plotting
200
+ for metric_name in all_metrics_data_combined:
201
+ all_metrics_data_combined[metric_name] = (
202
+ all_metrics_data_combined[metric_name]
203
+ .sort_values(by=["track", "global_step"])
204
+ .reset_index(drop=True)
205
+ )
206
+
207
+ return all_metrics_data_combined
src copy/data_models.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # data_models.py
2
+ from dataclasses import dataclass, field
3
+ from pathlib import Path
4
+ from typing import Dict, List, Optional, Any
5
+ import pandas as pd
6
+ import streamlit as st # For caching
7
+
8
+ # Import from data_loader, assuming it's in the same directory
9
+ # We'll define these functions in data_loader.py
10
+ # To avoid circular imports, data_loader functions won't import data_models directly for type hints if possible,
11
+ # or use string type hints.
12
+
13
+ # Forward declaration for type hint if data_loader needs Study/Trial
14
+ # class Study: pass
15
+ # class Trial: pass
16
+
17
+ from data_loader import (
18
+ load_input_variables_from_path,
19
+ load_all_metrics_for_trial_path,
20
+ discover_trials_from_path,
21
+ clear_trial_cache as clear_trial_loader_cache,
22
+ clear_study_cache as clear_study_loader_cache,
23
+ clear_specific_trial_metric_cache,
24
+ clear_specific_trial_input_vars_cache,
25
+ clear_specific_study_trial_discovery_cache,
26
+ )
27
+
28
+
29
+ @dataclass
30
+ class Trial:
31
+ name: str
32
+ path: Path
33
+ study_name: str # To know its parent study
34
+ input_variables: Dict[str, Any] = field(default_factory=dict, repr=False)
35
+ metrics_data: Dict[str, pd.DataFrame] = field(
36
+ default_factory=dict, repr=False
37
+ ) # Key: metric_name, Value: DataFrame with global_step, value, track
38
+
39
+ def __post_init__(self):
40
+ # Automatically load data if needed, but prefer explicit calls from UI for clarity
41
+ pass
42
+
43
+ # Use st.cache_data on the loader functions, not directly here for complex objects.
44
+ # Instead, methods here will call cached loader functions.
45
+
46
+ def load_input_variables_cached(self):
47
+ """Loads or retrieves cached input variables."""
48
+ if not self.input_variables: # Load only if not already populated
49
+ self.input_variables = load_input_variables_from_path(self.path)
50
+ return self.input_variables
51
+
52
+ def load_metrics_cached(self):
53
+ """Loads or retrieves cached metrics data."""
54
+ if not self.metrics_data: # Load only if not already populated
55
+ self.metrics_data = load_all_metrics_for_trial_path(self.path)
56
+ return self.metrics_data
57
+
58
+ def get_metric_dataframe(self, metric_name: str) -> Optional[pd.DataFrame]:
59
+ """Returns the DataFrame for a specific metric, combining all tracks."""
60
+ return self.metrics_data.get(metric_name)
61
+
62
+ def clear_cache(self):
63
+ """Clears cached data for this specific trial."""
64
+ # Clear Streamlit's cache for functions related to this trial
65
+ clear_specific_trial_metric_cache(self.path)
66
+ clear_specific_trial_input_vars_cache(self.path)
67
+ # Reset instance variables
68
+ self.input_variables = {}
69
+ self.metrics_data = {}
70
+ st.success(f"Trial '{self.name}' 的缓存已清除。")
71
+
72
+
73
+ @dataclass
74
+ class Study:
75
+ name: str
76
+ path: Path
77
+ trials: Dict[str, Trial] = field(default_factory=dict, repr=False)
78
+
79
+ def discover_trials_cached(self):
80
+ """Discovers or retrieves cached trials for this study."""
81
+ if not self.trials: # Discover only if not already populated
82
+ trial_paths = discover_trials_from_path(
83
+ self.path
84
+ ) # This loader function should be cached
85
+ for trial_name, trial_path in trial_paths.items():
86
+ self.trials[trial_name] = Trial(
87
+ name=trial_name, path=trial_path, study_name=self.name
88
+ )
89
+ return self.trials
90
+
91
+ def get_trial(self, trial_name: str) -> Optional[Trial]:
92
+ return self.trials.get(trial_name)
93
+
94
+ def clear_cache(self):
95
+ """Clears cached data for this study and its trials."""
96
+ clear_specific_study_trial_discovery_cache(self.path)
97
+ for trial in self.trials.values():
98
+ trial.clear_cache() # Clear cache for each trial within the study
99
+ self.trials = {} # Reset trials dictionary
100
+ st.success(f"Study '{self.name}' 及其 Trials 的缓存已清除。")
src copy/streamlit_app.py ADDED
@@ -0,0 +1,634 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ # --- Page Configuration ---
4
+ st.set_page_config(
5
+ layout="wide",
6
+ page_title="柳暗花明 (flowillower)",
7
+ page_icon=":sunrise_over_mountains:",
8
+ initial_sidebar_state="expanded",
9
+ )
10
+
11
+ from pathlib import Path
12
+ import plotly.graph_objects as go
13
+ import plotly.express as px
14
+ from plotly.subplots import make_subplots
15
+ import pandas as pd
16
+ import time
17
+
18
+ # --- Logo ---
19
+ st.logo("logo.png", icon_image="logo.png")
20
+
21
+ # 导入重构后的模块
22
+ try:
23
+ from utils import DATA_ROOT_PATH, AppMode
24
+ from data_models import Study, Trial # Study, Trial will be used
25
+ from data_loader import discover_studies_cached, ensure_data_directory_exists
26
+ from theme_selector import render_theme_selector # 新增:导入主题选择器
27
+ except ImportError as e:
28
+ st.error(
29
+ f"导入模块失败,请确保 utils.py, data_models.py, data_loader.py, theme_selector.py 文件存在于正确的位置: {e}"
30
+ )
31
+ st.stop()
32
+
33
+
34
+ # --- 应用状态管理 ---
35
+ if "selected_study_name" not in st.session_state:
36
+ st.session_state.selected_study_name = None
37
+ if "selected_trial_name" not in st.session_state:
38
+ st.session_state.selected_trial_name = None
39
+ # if "studies_data" not in st.session_state: # Not directly used, discover_studies_cached returns objects
40
+ # st.session_state.studies_data = {}
41
+ if "app_mode" not in st.session_state:
42
+ st.session_state.app_mode = AppMode.VIEWING
43
+
44
+ # 新增: 用于跨图表共享选中的 global_step
45
+ if "shared_selected_global_step" not in st.session_state:
46
+ st.session_state.shared_selected_global_step = None
47
+
48
+ # 新增: 自动播放相关状态
49
+ if "is_auto_playing" not in st.session_state:
50
+ st.session_state.is_auto_playing = False
51
+ if "auto_play_speed" not in st.session_state:
52
+ st.session_state.auto_play_speed = 1.0
53
+ if "auto_play_needs_rerun" not in st.session_state:
54
+ st.session_state.auto_play_needs_rerun = False
55
+
56
+
57
+ # --- UI Rendering ---
58
+
59
+ # --- Header ---
60
+ header_cols = st.columns([2, 3, 1.5, 0.5, 0.5, 0.5, 1]) # 新增一列用于主题选择器
61
+ with header_cols[0]:
62
+ st.markdown("## 柳暗花明")
63
+ st.caption("flowillower")
64
+
65
+ ensure_data_directory_exists(DATA_ROOT_PATH)
66
+ all_study_objects = discover_studies_cached(DATA_ROOT_PATH)
67
+ study_names = list(all_study_objects.keys())
68
+
69
+ if not study_names:
70
+ st.warning(
71
+ f"在 {DATA_ROOT_PATH} 未找到任何 Study。请确保您的数据结构正确或使用 flowillower API 开始记录实验。"
72
+ )
73
+
74
+ if study_names:
75
+ with header_cols[1]:
76
+ if st.session_state.selected_study_name not in study_names:
77
+ st.session_state.selected_study_name = (
78
+ study_names[0] if study_names else None
79
+ )
80
+
81
+ selected_study_name_from_ui = st.selectbox(
82
+ "选择 Study (Select Study)",
83
+ study_names,
84
+ index=study_names.index(st.session_state.selected_study_name)
85
+ if st.session_state.selected_study_name in study_names
86
+ else 0,
87
+ label_visibility="collapsed",
88
+ key="study_selector_main_ui",
89
+ )
90
+ if selected_study_name_from_ui != st.session_state.selected_study_name:
91
+ st.session_state.selected_study_name = selected_study_name_from_ui
92
+ st.session_state.selected_trial_name = None
93
+ st.session_state.shared_selected_global_step = None # Study 变化时清除高亮
94
+ st.rerun()
95
+
96
+ with header_cols[2]:
97
+ if st.session_state.selected_study_name:
98
+ st.write(f"当前 Study: **{st.session_state.selected_study_name}**")
99
+ else:
100
+ with header_cols[1]:
101
+ st.info("没有可用的 Study。")
102
+
103
+ with header_cols[3]:
104
+ st.button("➕", help="添加 (Add)", disabled=True)
105
+ with header_cols[4]:
106
+ st.button("⚙️", help="设置 (Settings)", disabled=True)
107
+ with header_cols[5]:
108
+ st.button("👤", help="用户 (User)", disabled=True)
109
+ with header_cols[6]: # 新增:主题选择器列
110
+ with st.container():
111
+ # st.markdown("**主题**")
112
+ render_theme_selector()
113
+ st.markdown("---")
114
+
115
+ # --- Sidebar ---
116
+ current_study: Study | None = None
117
+ if (
118
+ st.session_state.selected_study_name
119
+ and st.session_state.selected_study_name in all_study_objects
120
+ ):
121
+ current_study = all_study_objects[st.session_state.selected_study_name]
122
+ if not current_study.trials:
123
+ current_study.discover_trials_cached()
124
+
125
+ trial_names = list(current_study.trials.keys()) if current_study else []
126
+
127
+ with st.sidebar:
128
+ st.markdown("### Study")
129
+ if current_study:
130
+ st.markdown(f"##### {current_study.name}")
131
+ if st.button("刷新 Study 数据 (Refresh Study Data)", use_container_width=True):
132
+ current_study.clear_cache()
133
+ st.rerun()
134
+ if st.button("概览 (Overview)", use_container_width=True, disabled=True):
135
+ st.toast("功能待实现")
136
+ if st.button(
137
+ "图表对比视图 (Chart Comparison View)",
138
+ use_container_width=True,
139
+ disabled=True,
140
+ ):
141
+ st.toast("功能待实现")
142
+ else:
143
+ st.markdown("未选择 Study")
144
+
145
+ st.markdown("---")
146
+ st.markdown("### Trial")
147
+ if current_study and trial_names:
148
+ if st.session_state.selected_trial_name not in trial_names:
149
+ st.session_state.selected_trial_name = (
150
+ trial_names[0] if trial_names else None
151
+ )
152
+
153
+ selected_trial_name_from_ui = st.radio(
154
+ "选择 Trial (Select Trial)",
155
+ trial_names,
156
+ index=trial_names.index(st.session_state.selected_trial_name)
157
+ if st.session_state.selected_trial_name in trial_names
158
+ else 0,
159
+ label_visibility="collapsed",
160
+ key="trial_selector_sidebar_ui",
161
+ )
162
+ if selected_trial_name_from_ui != st.session_state.selected_trial_name:
163
+ st.session_state.selected_trial_name = selected_trial_name_from_ui
164
+ st.session_state.shared_selected_global_step = None # Trial 变化时清除高亮
165
+ st.rerun()
166
+ if st.session_state.selected_trial_name:
167
+ st.markdown(f"当前选择: **{st.session_state.selected_trial_name}**")
168
+ elif current_study:
169
+ st.info(f"Study '{current_study.name}' 中没有 Trial。")
170
+ else:
171
+ st.info("请先选择一个 Study。")
172
+ st.markdown("---")
173
+ if st.button("⚙️ App 设置 (App Settings)", use_container_width=True, disabled=True):
174
+ st.toast("功能待实现")
175
+
176
+ # --- Main Content Area ---
177
+ current_trial: Trial | None = None
178
+ if (
179
+ current_study
180
+ and st.session_state.selected_trial_name
181
+ and st.session_state.selected_trial_name in current_study.trials
182
+ ):
183
+ current_trial = current_study.trials[st.session_state.selected_trial_name]
184
+ current_trial.load_input_variables_cached()
185
+ current_trial.load_metrics_cached()
186
+
187
+ if current_study and current_trial:
188
+ main_title_cols = st.columns([3, 1, 0.5])
189
+ with main_title_cols[0]:
190
+ st.markdown(f"## {current_trial.name}")
191
+ st.caption(f"属于 Study: {current_study.name}")
192
+ with main_title_cols[1]:
193
+ if st.button("刷新 Trial 数据 (Refresh Trial Data)", type="secondary"):
194
+ current_trial.clear_cache()
195
+ st.rerun()
196
+ with main_title_cols[2]:
197
+ st.button("...", help="更多选项 (More Options)", disabled=True)
198
+
199
+ # 添加全局步骤控制器
200
+ if current_trial.metrics_data:
201
+ st.markdown("### 全局步骤控制 (Global Step Control)")
202
+
203
+ # 获取所有指标的全局步骤范围
204
+ all_global_steps = set()
205
+ for metric_name in current_trial.metrics_data.keys():
206
+ df_metric = current_trial.get_metric_dataframe(metric_name)
207
+ if (
208
+ df_metric is not None
209
+ and not df_metric.empty
210
+ and "global_step" in df_metric.columns
211
+ ):
212
+ all_global_steps.update(df_metric["global_step"].tolist())
213
+
214
+ if all_global_steps:
215
+ all_global_steps = sorted(list(all_global_steps))
216
+ min_step, max_step = min(all_global_steps), max(all_global_steps)
217
+
218
+ # 控制器布局
219
+ control_cols = st.columns([3, 1, 1, 1])
220
+
221
+ with control_cols[0]:
222
+ # 滑动条
223
+ if st.session_state.shared_selected_global_step is None:
224
+ # 默认选择最后一个step
225
+ st.session_state.shared_selected_global_step = max_step
226
+
227
+ # 确保当前选中的步骤在有效范围内
228
+ if st.session_state.shared_selected_global_step not in all_global_steps:
229
+ # 找到最接近的有效步骤
230
+ closest_step = min(
231
+ all_global_steps,
232
+ key=lambda x: abs(
233
+ x - st.session_state.shared_selected_global_step
234
+ ),
235
+ )
236
+ st.session_state.shared_selected_global_step = closest_step
237
+
238
+ selected_step = st.select_slider(
239
+ "选择全局步骤",
240
+ options=all_global_steps,
241
+ value=st.session_state.shared_selected_global_step,
242
+ format_func=lambda x: f"Step {x}",
243
+ key="global_step_slider",
244
+ )
245
+
246
+ if selected_step != st.session_state.shared_selected_global_step:
247
+ st.session_state.shared_selected_global_step = selected_step
248
+ st.rerun()
249
+
250
+ with control_cols[1]:
251
+ # 播放/暂停按钮
252
+ if st.session_state.is_auto_playing:
253
+ if st.button("⏸️ 暂停", type="primary", use_container_width=True):
254
+ st.session_state.is_auto_playing = False
255
+ st.rerun()
256
+ else:
257
+ if st.button("▶️ 播放", type="primary", use_container_width=True):
258
+ st.session_state.is_auto_playing = True
259
+ st.rerun()
260
+
261
+ with control_cols[2]:
262
+ # 速度控制
263
+ speed = st.selectbox(
264
+ "播放速度",
265
+ options=[0.5, 1.0, 2.0, 4.0],
266
+ index=[0.5, 1.0, 2.0, 4.0].index(st.session_state.auto_play_speed),
267
+ format_func=lambda x: f"{x}x",
268
+ key="speed_selector",
269
+ )
270
+ if speed != st.session_state.auto_play_speed:
271
+ st.session_state.auto_play_speed = speed
272
+
273
+ with control_cols[3]:
274
+ # 重置按钮
275
+ if st.button("🔄 重置", use_container_width=True):
276
+ st.session_state.shared_selected_global_step = min_step
277
+ st.session_state.is_auto_playing = False
278
+ st.rerun()
279
+
280
+ # 自动播放逻辑 - 设置标志但不立即rerun
281
+ if st.session_state.is_auto_playing:
282
+ current_index = all_global_steps.index(
283
+ st.session_state.shared_selected_global_step
284
+ )
285
+ if current_index < len(all_global_steps) - 1:
286
+ # 等待指定时间后移动到下一步
287
+ time.sleep(1.0 / st.session_state.auto_play_speed)
288
+ st.session_state.shared_selected_global_step = all_global_steps[
289
+ current_index + 1
290
+ ]
291
+ st.session_state.auto_play_needs_rerun = True
292
+ else:
293
+ # 到达末尾,停止播放
294
+ st.session_state.is_auto_playing = False
295
+ st.session_state.auto_play_needs_rerun = True
296
+
297
+ # 显示当前步骤信息
298
+ st.info(
299
+ f"当前选中步骤: **{st.session_state.shared_selected_global_step}** / {max_step}"
300
+ )
301
+
302
+ st.markdown("---")
303
+
304
+ tab_titles = [
305
+ "图表 (Charts)",
306
+ "参数 (Parameters)",
307
+ "系统 (System)",
308
+ "日志 (Logs)",
309
+ "环境 (Environment)",
310
+ ]
311
+ tab_charts, tab_params, tab_system, tab_logs, tab_env = st.tabs(tab_titles)
312
+
313
+ with tab_charts:
314
+ st.header("指标图表 (Metrics Charts)")
315
+ st.markdown("---")
316
+
317
+ if not current_trial.metrics_data:
318
+ st.info("当前 Trial 没有可显示的指标数据。")
319
+ else:
320
+ num_metrics = len(current_trial.metrics_data)
321
+ cols_per_row = st.slider(
322
+ "每行图表数量 (Charts per row)",
323
+ 1,
324
+ 4,
325
+ min(2, num_metrics) if num_metrics > 0 else 1,
326
+ key=f"cols_slider_{current_study.name}_{current_trial.name}",
327
+ )
328
+ metric_names = sorted(list(current_trial.metrics_data.keys()))
329
+
330
+ for i in range(0, num_metrics, cols_per_row):
331
+ metric_chunk = metric_names[i : i + cols_per_row]
332
+ chart_cols = st.columns(cols_per_row)
333
+ for j, metric_name in enumerate(metric_chunk):
334
+ with chart_cols[j]:
335
+ df_metric = current_trial.get_metric_dataframe(metric_name)
336
+ if df_metric is None or df_metric.empty:
337
+ st.warning(f"指标 '{metric_name}' 数据不完整或缺失。")
338
+ continue
339
+
340
+ with st.container(border=True):
341
+ st.subheader(metric_name)
342
+
343
+ # 添加metric组件 - 显示当前值和增量
344
+ try:
345
+ current_step = (
346
+ st.session_state.shared_selected_global_step
347
+ )
348
+
349
+ # 获取所有可能的track
350
+ all_tracks = (
351
+ df_metric["track"].unique()
352
+ if "track" in df_metric.columns
353
+ else [None]
354
+ )
355
+
356
+ # 为每个track创建metric组件
357
+ if len(all_tracks) > 1:
358
+ metric_cols = st.columns(len(all_tracks))
359
+ else:
360
+ metric_cols = [st] # 使用整个容器
361
+
362
+ for idx, track in enumerate(all_tracks):
363
+ # 查找当前步骤的数据
364
+ if track is not None:
365
+ current_step_data = df_metric[
366
+ (df_metric["global_step"] == current_step)
367
+ & (df_metric["track"] == track)
368
+ ]
369
+ else:
370
+ current_step_data = df_metric[
371
+ df_metric["global_step"] == current_step
372
+ ]
373
+
374
+ current_value = None
375
+ delta_value = None
376
+
377
+ # 如果当前步骤没有该track的数据,向前查找最近的步骤
378
+ if current_step_data.empty:
379
+ # 向前查找最近的有该track数据的步骤
380
+ current_index = all_global_steps.index(
381
+ current_step
382
+ )
383
+ for search_idx in range(
384
+ current_index - 1, -1, -1
385
+ ):
386
+ search_step = all_global_steps[search_idx]
387
+ if track is not None:
388
+ search_data = df_metric[
389
+ (
390
+ df_metric["global_step"]
391
+ == search_step
392
+ )
393
+ & (df_metric["track"] == track)
394
+ ]
395
+ else:
396
+ search_data = df_metric[
397
+ df_metric["global_step"]
398
+ == search_step
399
+ ]
400
+
401
+ if not search_data.empty:
402
+ current_value = search_data[
403
+ "value"
404
+ ].iloc[0]
405
+ current_step_found = search_step
406
+ break
407
+ else:
408
+ current_value = current_step_data["value"].iloc[
409
+ 0
410
+ ]
411
+ current_step_found = current_step
412
+
413
+ # 计算增量:查找比当前找到的步骤更早的数据
414
+ if current_value is not None:
415
+ current_found_index = all_global_steps.index(
416
+ current_step_found
417
+ )
418
+ for prev_idx in range(
419
+ current_found_index - 1, -1, -1
420
+ ):
421
+ prev_step = all_global_steps[prev_idx]
422
+ if track is not None:
423
+ prev_step_data = df_metric[
424
+ (
425
+ df_metric["global_step"]
426
+ == prev_step
427
+ )
428
+ & (df_metric["track"] == track)
429
+ ]
430
+ else:
431
+ prev_step_data = df_metric[
432
+ df_metric["global_step"]
433
+ == prev_step
434
+ ]
435
+
436
+ if not prev_step_data.empty:
437
+ prev_value = prev_step_data[
438
+ "value"
439
+ ].iloc[0]
440
+ delta_value = current_value - prev_value
441
+ break
442
+
443
+ # 显示metric组件
444
+ with (
445
+ metric_cols[idx]
446
+ if len(all_tracks) > 1
447
+ else metric_cols[0]
448
+ ):
449
+ if current_value is not None:
450
+ # 确定label
451
+ if track is not None:
452
+ if current_step_found != current_step:
453
+ label = f"{track} (Step {current_step_found})"
454
+ else:
455
+ label = f"{track}"
456
+ else:
457
+ if current_step_found != current_step:
458
+ label = f"当前值 (Step {current_step_found})"
459
+ else:
460
+ label = (
461
+ f"当前值 (Step {current_step})"
462
+ )
463
+
464
+ st.metric(
465
+ label=label,
466
+ value=f"{current_value:.4f}",
467
+ delta=f"{delta_value:.4f}"
468
+ if delta_value is not None
469
+ else None,
470
+ )
471
+ else:
472
+ # 没有找到任何数据
473
+ track_label = (
474
+ track if track is not None else "数据"
475
+ )
476
+ st.metric(
477
+ label=f"{track_label}",
478
+ value="无数据",
479
+ delta=None,
480
+ )
481
+
482
+ except Exception as e:
483
+ st.warning(f"计算指标值时出错: {e}")
484
+
485
+ try:
486
+ # 创建 Plotly 图表
487
+ fig = go.Figure()
488
+
489
+ # 按track分组绘制线条
490
+ if "track" in df_metric.columns:
491
+ tracks = df_metric["track"].unique()
492
+ colors = px.colors.qualitative.Set1[: len(tracks)]
493
+
494
+ for k, track in enumerate(tracks):
495
+ track_data = df_metric[
496
+ df_metric["track"] == track
497
+ ]
498
+ fig.add_trace(
499
+ go.Scatter(
500
+ x=track_data["global_step"],
501
+ y=track_data["value"],
502
+ mode="lines+markers",
503
+ name=track,
504
+ line=dict(
505
+ color=colors[k % len(colors)]
506
+ ),
507
+ marker=dict(
508
+ size=6,
509
+ color=colors[k % len(colors)],
510
+ line=dict(width=1, color="white"),
511
+ ),
512
+ customdata=track_data[
513
+ ["global_step", "value", "track"]
514
+ ],
515
+ hovertemplate="<b>%{fullData.name}</b><br>"
516
+ + "Global Step: %{x}<br>"
517
+ + "Value: %{y}<br>"
518
+ + "<extra></extra>",
519
+ )
520
+ )
521
+ else:
522
+ # 如果没有track列,绘制单条线
523
+ fig.add_trace(
524
+ go.Scatter(
525
+ x=df_metric["global_step"],
526
+ y=df_metric["value"],
527
+ mode="lines+markers",
528
+ name=metric_name,
529
+ marker=dict(
530
+ size=6,
531
+ line=dict(width=1, color="white"),
532
+ ),
533
+ customdata=df_metric[
534
+ ["global_step", "value"]
535
+ ],
536
+ hovertemplate="Global Step: %{x}<br>"
537
+ + "Value: %{y}<br>"
538
+ + "<extra></extra>",
539
+ )
540
+ )
541
+
542
+ # 如果有共享的选中步骤,添加高亮线
543
+ if (
544
+ st.session_state.shared_selected_global_step
545
+ is not None
546
+ ):
547
+ fig.add_vline(
548
+ x=st.session_state.shared_selected_global_step,
549
+ line_width=2,
550
+ line_dash="solid",
551
+ line_color="firebrick",
552
+ opacity=0.9,
553
+ )
554
+
555
+ # 设置图表布局
556
+ fig.update_layout(
557
+ title=None,
558
+ xaxis_title="全局步骤 (Global Step)",
559
+ yaxis_title=metric_name,
560
+ height=400,
561
+ margin=dict(l=0, r=0, t=0, b=0),
562
+ showlegend=True
563
+ if "track" in df_metric.columns
564
+ and len(df_metric["track"].unique()) > 1
565
+ else False,
566
+ hovermode="closest",
567
+ )
568
+
569
+ # 显示图表并处理点击事件
570
+ chart_key = f"chart_metric_{current_study.name}_{current_trial.name}_{metric_name}"
571
+ clicked_points = st.plotly_chart(
572
+ fig,
573
+ use_container_width=True,
574
+ key=chart_key,
575
+ on_select="rerun",
576
+ )
577
+
578
+ # 处理点击事件
579
+ if clicked_points and "selection" in clicked_points:
580
+ selection = clicked_points["selection"]
581
+ if (
582
+ "points" in selection
583
+ and len(selection["points"]) > 0
584
+ ):
585
+ # 获取第一个点击点的 x 坐标 (global_step)
586
+ clicked_x = selection["points"][0]["x"]
587
+ if clicked_x is not None:
588
+ new_step = int(clicked_x)
589
+ if (
590
+ st.session_state.get(
591
+ "shared_selected_global_step"
592
+ )
593
+ != new_step
594
+ ):
595
+ st.session_state.shared_selected_global_step = new_step
596
+ # 点击图表时停止自动播放
597
+ st.session_state.is_auto_playing = False
598
+ st.rerun()
599
+
600
+ except Exception as e:
601
+ st.error(f"为指标 '{metric_name}' 生成图表时出错: {e}")
602
+ st.dataframe(df_metric)
603
+ # raise e
604
+
605
+ with tab_params:
606
+ st.header("输入参数 (Input Parameters)")
607
+ if current_trial.input_variables:
608
+ st.json(current_trial.input_variables)
609
+ else:
610
+ st.info("未找到 `input_variables.toml` 或文件为空。")
611
+
612
+ for tab_content, name in [
613
+ (tab_system, "系统监控"),
614
+ (tab_logs, "日志"),
615
+ (tab_env, "环境"),
616
+ ]:
617
+ with tab_content:
618
+ st.header(name)
619
+ st.info("此功能待您的 `flowillower` API 提供相关数据后实现。")
620
+
621
+ elif not st.session_state.selected_study_name:
622
+ st.info("👈 请从顶部选择一个 Study 开始。")
623
+ elif not st.session_state.selected_trial_name:
624
+ st.info("👈 请从侧边栏选择一个 Trial。")
625
+ else:
626
+ st.info("请选择 Study 和 Trial 以查看数据。")
627
+
628
+ st.markdown("---")
629
+ st.caption("柳暗花明 (flowillower) - 数据可视化App")
630
+
631
+ # 在页面最后处理自动播放的rerun
632
+ if st.session_state.get("auto_play_needs_rerun", False):
633
+ st.session_state.auto_play_needs_rerun = False
634
+ st.rerun()
src copy/theme_selector.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import toml
3
+ from pathlib import Path
4
+ import os
5
+ import time
6
+
7
+
8
+ class ThemeSelector:
9
+ def __init__(
10
+ self, themes_dir=".streamlit/themes", config_path=".streamlit/config.toml"
11
+ ):
12
+ self.themes_dir = Path(themes_dir)
13
+ self.config_path = Path(config_path)
14
+ self.themes = {}
15
+ self.load_themes()
16
+
17
+ def load_themes(self):
18
+ """加载所有主题文件"""
19
+ self.themes = {}
20
+ if not self.themes_dir.exists():
21
+ return
22
+
23
+ for theme_file in self.themes_dir.glob("*.toml"):
24
+ try:
25
+ theme_data = toml.load(theme_file)
26
+
27
+ # 从根级别获取theme_name和theme_poem
28
+ theme_name = theme_data.get("theme_name", theme_file.stem)
29
+ theme_poem = theme_data.get("theme_poem", "")
30
+ theme_config = theme_data.get("theme", {})
31
+
32
+ self.themes[theme_name] = {
33
+ "file": theme_file,
34
+ "name": theme_name,
35
+ "poem": theme_poem,
36
+ "config": theme_config,
37
+ }
38
+ except Exception as e:
39
+ st.warning(f"读取主题文件 {theme_file} 失败: {e}")
40
+
41
+ def get_current_theme(self):
42
+ """获取当前主题名称"""
43
+ if not self.config_path.exists():
44
+ return None
45
+
46
+ try:
47
+ # config = toml.load(self.config_path)
48
+ # # 从根级别读取theme_name
49
+ # current_theme_name = config.get("theme") or {}
50
+ # current_theme_name = current_theme_name.get("theme_name")
51
+ # return current_theme_name
52
+ theme_toml = self.config_path.parent / "theme.toml"
53
+ theme = toml.load(theme_toml)
54
+ return theme.get("theme_name")
55
+
56
+ except Exception:
57
+ return None
58
+
59
+ def apply_theme(self, theme_name):
60
+ """应用选定的主题"""
61
+ if theme_name not in self.themes:
62
+ st.error(f"主题 '{theme_name}' 不存在")
63
+ return False
64
+
65
+ try:
66
+ # 确保配置目录存在
67
+ self.config_path.parent.mkdir(parents=True, exist_ok=True)
68
+
69
+ # 读取现有配置或创建新配置
70
+ config = {}
71
+ if self.config_path.exists():
72
+ try:
73
+ config = toml.load(self.config_path)
74
+ except Exception:
75
+ config = {}
76
+
77
+ # 添加根级别的theme_name和theme_poem
78
+ # config["theme_name"] = self.themes[theme_name]["name"]
79
+ # config["theme_poem"] = self.themes[theme_name]["poem"]
80
+
81
+ # 更新主题配置
82
+ theme_config = self.themes[theme_name]["config"].copy()
83
+ # theme_config["theme_name"] = self.themes[theme_name]["name"]
84
+ # theme_config["theme_poem"] = self.themes[theme_name]["poem"]
85
+ config["theme"] |= theme_config
86
+
87
+ # 写入配置文件
88
+ with open(self.config_path, "w", encoding="utf-8") as f:
89
+ toml.dump(config, f)
90
+
91
+ theme_toml = self.config_path.parent / "theme.toml"
92
+ with open(theme_toml, "w", encoding="utf-8") as f:
93
+ toml.dump(
94
+ dict(
95
+ theme_name=self.themes[theme_name]["name"],
96
+ theme_poem=self.themes[theme_name]["poem"],
97
+ ),
98
+ f,
99
+ )
100
+
101
+ return True
102
+
103
+ except Exception as e:
104
+ st.error(f"应用主题失败: {e}")
105
+ return False
106
+
107
+ def render_theme_selector(self):
108
+ """渲染主题选择器UI"""
109
+ if not self.themes:
110
+ st.warning("未找到可用主题")
111
+ return
112
+
113
+ theme_names = list(self.themes.keys())
114
+ current_theme = self.get_current_theme()
115
+
116
+ # 确定当前选中的索引
117
+ current_index = 0
118
+ if current_theme and current_theme in theme_names:
119
+ current_index = theme_names.index(current_theme)
120
+
121
+ # 主题选择下拉菜单
122
+ selected_theme = st.selectbox(
123
+ "选择主题",
124
+ options=theme_names,
125
+ index=current_index,
126
+ format_func=lambda x: self.themes[x]["name"],
127
+ key="theme_selector_widget",
128
+ label_visibility="collapsed",
129
+ )
130
+
131
+ # 如果选择了新主题
132
+ if selected_theme != current_theme:
133
+ if self.apply_theme(selected_theme):
134
+ # 显示主题诗句
135
+ theme_poem = self.themes[selected_theme]["poem"]
136
+ if theme_poem:
137
+ st.toast(f"✨ {theme_poem}", icon="🎨")
138
+ else:
139
+ st.toast(f"已切换到主题: {selected_theme}", icon="🎨")
140
+
141
+ time.sleep(3)
142
+ # 延迟重新运行以应用主题
143
+ st.rerun()
144
+
145
+ return selected_theme
146
+
147
+
148
+ # 全局主题选择器实例
149
+ _theme_selector = None
150
+
151
+
152
+ def get_theme_selector():
153
+ """获取全局主题选择器实例"""
154
+ global _theme_selector
155
+ if _theme_selector is None:
156
+ _theme_selector = ThemeSelector()
157
+ return _theme_selector
158
+
159
+
160
+ def render_theme_selector():
161
+ """便捷函数:渲染主题选择器"""
162
+ return get_theme_selector().render_theme_selector()
{src → src copy}/utils.py RENAMED
File without changes
src/.streamlit/config.toml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [theme]
2
+ base = "light"
3
+ baseFontSize = 15
4
+ primaryColor = "#FF5F7E"
5
+ backgroundColor = "#F9FAFB"
6
+ secondaryBackgroundColor = "#F0F4F8"
7
+ textColor = "#1F2937"
8
+ linkColor = "#2563EB"
9
+ borderColor = "#D1D5DB"
10
+ showWidgetBorder = false
11
+ baseRadius = "0.3rem"
12
+ font = "Poppins"
13
+
14
+ [theme.sidebar]
15
+ backgroundColor = "#FFFFFF"
16
+ secondaryBackgroundColor = "#F3F4F6"
17
+ borderColor = "#D1D5DB"
src/.streamlit/theme.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ theme_name = "花明 (Flowers Bright) 🌸"
2
+ theme_poem = "🌸「浅色但不苍白,明亮而不过曝,柔和中有力量」"
src/.streamlit/themes/antropic.toml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ theme_name = "Antropic"
2
+
3
+ [theme]
4
+ primaryColor = "#bb5a38"
5
+ backgroundColor = "#f4f3ed"
6
+ secondaryBackgroundColor = "#ecebe3"
7
+ textColor = "#3d3a2a"
8
+ linkColor = "#3d3a2a"
9
+ borderColor = "#d3d2ca"
10
+ showWidgetBorder = true
11
+ baseRadius = "0.6rem"
12
+ font = "SpaceGrotesk"
13
+ headingFont = "SpaceGroteskHeader"
14
+ codeFont = "SpaceMono"
15
+ codeBackgroundColor = "#ecebe4"
16
+ showSidebarBorder = true
17
+
18
+ [theme.sidebar]
19
+ backgroundColor = "#e8e7dd"
20
+ secondaryBackgroundColor = "#ecebe3"
src/.streamlit/themes/flowers_bright.toml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ theme_name = "花明 (Flowers Bright) 🌸"
2
+ theme_poem = "🌸「浅色但不苍白,明亮而不过曝,柔和中有力量」"
3
+
4
+ [theme]
5
+ base = "light"
6
+ baseFontSize = 15
7
+ primaryColor = "#FF5F7E" # 樱花粉
8
+ backgroundColor = "#F9FAFB" # 轻柔灰白
9
+ secondaryBackgroundColor = "#F0F4F8" # 卡片淡蓝灰
10
+ textColor = "#1F2937" # 蓝黑灰
11
+ linkColor = "#2563EB" # 浅蓝色链接
12
+ borderColor = "#D1D5DB" # 卡片分界线
13
+ showWidgetBorder = false
14
+ baseRadius = "0.3rem"
15
+ font = "Poppins" # 保持现代圆润感
16
+
17
+ [theme.sidebar]
18
+ backgroundColor = "#FFFFFF"
19
+ secondaryBackgroundColor = "#F3F4F6"
20
+ borderColor = "#D1D5DB"
src/.streamlit/themes/spotify.toml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ theme_name = "Spotify"
2
+
3
+ [theme]
4
+ base = "dark"
5
+ baseFontSize = 15
6
+ primaryColor = "#1ED760"
7
+ backgroundColor = "#121212"
8
+ secondaryBackgroundColor = "#333333"
9
+ textColor = "#FFFFFF"
10
+ linkColor = "#9D9D9D"
11
+ borderColor = "#7F7F7F"
12
+ showWidgetBorder = false
13
+ baseRadius = "0.3rem"
14
+ font = "Poppins"
15
+
16
+ [theme.sidebar]
17
+ backgroundColor = "#000000"
18
+ secondaryBackgroundColor = "#333333"
19
+ borderColor = "#696969"
src/.streamlit/themes/willows_dark.toml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ theme_name = "柳暗 (Willows Dark) 🌒"
2
+ theme_poem = "🌒「深而不死黑,蓝而不夺目,静而不沉闷」柳影婆娑之下,代码悄然生长。"
3
+
4
+ [theme]
5
+ base = "dark"
6
+ baseFontSize = 15
7
+ primaryColor = "#6EA8FE" # 柔和的蓝色光感
8
+ backgroundColor = "#0D1117" # 深夜蓝黑
9
+ secondaryBackgroundColor = "#1A1F2B" # 柔和灰蓝
10
+ textColor = "#D1D5DB" # 清晰柔白
11
+ linkColor = "#B8C0FF" # 柔紫蓝
12
+ borderColor = "#2E3440" # 北极灰边
13
+ showWidgetBorder = false
14
+ baseRadius = "0.3rem"
15
+ font = "JetBrains Mono" # 强科技感
16
+
17
+ [theme.sidebar]
18
+ backgroundColor = "#0A0A0A"
19
+ secondaryBackgroundColor = "#1A1A1A"
20
+ borderColor = "#2E3440"
src/data_loader.py CHANGED
@@ -7,7 +7,9 @@ import streamlit as st
7
  from functools import lru_cache # For non-Streamlit specific caching if needed
8
 
9
  # Assuming utils.py is in the same directory
10
- from utils import DATA_ROOT_PATH # Used for ensuring directory exists
 
 
11
 
12
  # --- Cache Clearing Functions ---
13
  # These are more specific cache clearing functions that can be called by model methods.
 
7
  from functools import lru_cache # For non-Streamlit specific caching if needed
8
 
9
  # Assuming utils.py is in the same directory
10
+ from infra import (
11
+ DATA_ROOT_PATH,
12
+ ) # Used for ensuring directory exists
13
 
14
  # --- Cache Clearing Functions ---
15
  # These are more specific cache clearing functions that can be called by model methods.
src/infra.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # utils.py
2
+ from pathlib import Path
3
+ from enum import Enum, auto
4
+
5
+ # Base path for studies and trials.
6
+ # Streamlit apps are typically run from their root directory.
7
+ # If your app.py is in 'src/', and 'data/' is at the same level as 'src/',
8
+ # then Path("./data") from app.py's perspective would be Path("../data").
9
+ # For simplicity, assuming data is relative to where streamlit run is executed,
10
+ # or you adjust this path accordingly.
11
+ DATA_ROOT_PATH = Path("./data").resolve()
12
+
13
+
14
+ class AppMode(Enum):
15
+ VIEWING = auto()
16
+ # Potentially other modes like COMPARISON, EDITING etc.
src/logo.png ADDED

Git LFS Details

  • SHA256: 6184b7d0eb9265bf8928c9cd9731dec91589eb57aafc97e4d5be804bf1588f16
  • Pointer size: 132 Bytes
  • Size of remote file: 1.69 MB
src/streamlit_app.py CHANGED
@@ -20,7 +20,7 @@ st.logo("logo.png", icon_image="logo.png")
20
 
21
  # 导入重构后的模块
22
  try:
23
- from utils import DATA_ROOT_PATH, AppMode
24
  from data_models import Study, Trial # Study, Trial will be used
25
  from data_loader import discover_studies_cached, ensure_data_directory_exists
26
  from theme_selector import render_theme_selector # 新增:导入主题选择器
@@ -441,46 +441,84 @@ if current_study and current_trial:
441
  break
442
 
443
  # 显示metric组件
444
- with (
445
- metric_cols[idx]
446
- if len(all_tracks) > 1
447
- else metric_cols[0]
448
- ):
449
- if current_value is not None:
450
- # 确定label
451
- if track is not None:
452
- if current_step_found != current_step:
453
- label = f"{track} (Step {current_step_found})"
 
 
 
 
 
 
454
  else:
455
- label = f"{track}"
 
 
 
 
 
 
 
 
 
 
 
 
 
456
  else:
457
- if current_step_found != current_step:
458
- label = f"当前值 (Step {current_step_found})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
  else:
460
- label = (
461
- f"当前值 (Step {current_step})"
462
- )
463
-
464
- st.metric(
465
- label=label,
466
- value=f"{current_value:.4f}",
467
- delta=f"{delta_value:.4f}"
468
- if delta_value is not None
469
- else None,
470
- )
471
- else:
472
- # 没有找到任何数据
473
- track_label = (
474
- track if track is not None else "数据"
475
- )
476
- st.metric(
477
- label=f"{track_label}",
478
- value="无数据",
479
- delta=None,
480
- )
 
 
 
481
 
482
  except Exception as e:
483
  st.warning(f"计算指标值时出错: {e}")
 
484
 
485
  try:
486
  # 创建 Plotly 图表
 
20
 
21
  # 导入重构后的模块
22
  try:
23
+ from infra import DATA_ROOT_PATH, AppMode
24
  from data_models import Study, Trial # Study, Trial will be used
25
  from data_loader import discover_studies_cached, ensure_data_directory_exists
26
  from theme_selector import render_theme_selector # 新增:导入主题选择器
 
441
  break
442
 
443
  # 显示metric组件
444
+ # print(metric_cols)
445
+ # print(metric_cols[idx])
446
+ # print(len(metric_cols), len(all_tracks))
447
+ metric_col = metric_cols[0] if len(metric_cols) == 1 else metric_cols[idx]
448
+
449
+ try:
450
+ with (
451
+ metric_col
452
+ ):
453
+ if current_value is not None:
454
+ # 确定label
455
+ if track is not None:
456
+ if current_step_found != current_step:
457
+ label = f"{track} (Step {current_step_found})"
458
+ else:
459
+ label = f"{track}"
460
  else:
461
+ if current_step_found != current_step:
462
+ label = f"当前值 (Step {current_step_found})"
463
+ else:
464
+ label = (
465
+ f"当前值 (Step {current_step})"
466
+ )
467
+
468
+ st.metric(
469
+ label=label,
470
+ value=f"{current_value:.4f}",
471
+ delta=f"{delta_value:.4f}"
472
+ if delta_value is not None
473
+ else None,
474
+ )
475
  else:
476
+ # 没有找到任何数据
477
+ track_label = (
478
+ track if track is not None else "数据"
479
+ )
480
+ st.metric(
481
+ label=f"{track_label}",
482
+ value="无数据",
483
+ delta=None,
484
+ )
485
+ except Exception as e:
486
+ if current_value is not None:
487
+ # 确定label
488
+ if track is not None:
489
+ if current_step_found != current_step:
490
+ label = f"{track} (Step {current_step_found})"
491
+ else:
492
+ label = f"{track}"
493
  else:
494
+ if current_step_found != current_step:
495
+ label = f"当前值 (Step {current_step_found})"
496
+ else:
497
+ label = (
498
+ f"当前值 (Step {current_step})"
499
+ )
500
+
501
+ st.metric(
502
+ label=label,
503
+ value=f"{current_value:.4f}",
504
+ delta=f"{delta_value:.4f}"
505
+ if delta_value is not None
506
+ else None,
507
+ )
508
+ else:
509
+ # 没有找到任何数据
510
+ track_label = (
511
+ track if track is not None else "数据"
512
+ )
513
+ st.metric(
514
+ label=f"{track_label}",
515
+ value="无数据",
516
+ delta=None,
517
+ )
518
 
519
  except Exception as e:
520
  st.warning(f"计算指标值时出错: {e}")
521
+ raise e
522
 
523
  try:
524
  # 创建 Plotly 图表
src/test.py DELETED
@@ -1,42 +0,0 @@
1
- import plotly.graph_objects as go
2
- from plotly.subplots import make_subplots
3
- import numpy as np
4
-
5
- # 创建示例数据
6
- x = np.linspace(0, 10, 100)
7
- y1 = np.sin(x)
8
- y2 = np.cos(x)
9
-
10
- # 创建两个 FigureWidget 图表
11
- fig1 = go.FigureWidget(data=[go.Scatter(x=x, y=y1, mode="lines", name="sin(x)")])
12
- fig2 = go.FigureWidget(data=[go.Scatter(x=x, y=y2, mode="lines", name="cos(x)")])
13
-
14
-
15
- # 定义悬停事件的回调函数
16
- def hover_fn(trace, points, state):
17
- if points.xs:
18
- hover_x = points.xs[0]
19
- with fig2.batch_update():
20
- # 在第二个图表上添加垂直线
21
- fig2.layout.shapes = [
22
- dict(
23
- type="line",
24
- x0=hover_x,
25
- x1=hover_x,
26
- y0=min(y2),
27
- y1=max(y2),
28
- line=dict(color="red", dash="dot"),
29
- )
30
- ]
31
-
32
-
33
- # 为第一个图表的第一个 trace 注册悬停事件
34
- fig1.data[0].on_hover(hover_fn)
35
-
36
- # 显示图表
37
- import streamlit as st
38
-
39
- st.plotly_chart(fig1, use_container_width=True)
40
- st.plotly_chart(fig2, use_container_width=True)
41
-
42
- # %%
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/visualizers/base_visualizer.py DELETED
@@ -1,254 +0,0 @@
1
- # src/visualizers/base_visualizer.py
2
- from abc import ABC, abstractmethod
3
- from pathlib import Path
4
- from typing import Dict, Type, Any, Optional, Callable, List
5
- import streamlit as st
6
-
7
- # --- Component Registry ---
8
- VISUALIZER_REGISTRY: Dict[str, Type["VisualizationComponent"]] = {}
9
-
10
- def register_visualizer(name: str):
11
- """
12
- 一个装饰器,用于将可视化组件类注册到全局注册表。
13
- A decorator to register a visualization component class to the global registry.
14
- """
15
- def decorator(cls: Type["VisualizationComponent"]):
16
- if name in VISUALIZER_REGISTRY:
17
- # 在调试或开发模式下可能是警告,生产模式下可能是错误
18
- # In debug/dev mode this might be a warning, in production an error
19
- print(f"警告: 可视化组件 '{name}' 已被注册,将被覆盖。Visualizer '{name}' already registered. Will be overridden.")
20
- VISUALIZER_REGISTRY[name] = cls
21
- return cls
22
- return decorator
23
-
24
- def get_visualizer_class(type_name: str) -> Optional[Type["VisualizationComponent"]]:
25
- """
26
- 从注册表中获取组件类。
27
- Gets a component class from the registry.
28
- """
29
- return VISUALIZER_REGISTRY.get(type_name)
30
-
31
- # --- Abstract Base Class for Visualization Components ---
32
- class VisualizationComponent(ABC):
33
- """
34
- 可视化组件的抽象基类。
35
- Abstract base class for all visualization components.
36
- """
37
- def __init__(self,
38
- component_instance_id: str,
39
- trial_root_path: Path,
40
- # data_asset_info: Dict[str, Any], # 描述此组件主要关联的数据资产信息 (来自清单)
41
- # Describes the data asset this component is primarily associated with (from manifest)
42
- # ^ 将被更通用的 data_sources_map 替代
43
- data_sources_map: Dict[str, Dict[str, Any]], # Key: 逻辑数据源名称, Value: 数据资产信息字典
44
- # Key: logical data source name, Value: data asset info dict
45
- # e.g., {"main_scalar_data": asset_info_for_loss, "reference_images": asset_info_for_images}
46
- component_specific_config: Dict[str, Any] = None
47
- ):
48
- """
49
- 初始化可视化组件。
50
- Initializes the visualization component.
51
-
52
- Args:
53
- component_instance_id (str): 此组件在仪表盘上的唯一实例ID。
54
- Unique instance ID for this component on the dashboard.
55
- trial_root_path (Path): 此组件所属的Trial的根目录路径。
56
- Root directory path of the Trial this component belongs to.
57
- data_sources_map (Dict[str, Dict[str, Any]]):
58
- 一个字典,映射逻辑数据源名称到具体的数据资产信息。
59
- 数据资产信息字典通常包含 'asset_id', 'display_name', 'data_type', 'path',
60
- 以及其他从 _trial_manifest.toml 中解析得到的元数据。
61
- A dictionary mapping logical data source names to specific data asset information.
62
- The data asset info dict typically contains 'asset_id', 'display_name', 'data_type', 'path',
63
- and other metadata parsed from _trial_manifest.toml.
64
- component_specific_config (Dict[str, Any], optional):
65
- 特定于此组件实例的配置 (例如,图表标题、颜色等)。
66
- Configuration specific to this component instance (e.g., chart title, color, etc.).
67
- """
68
- self.component_instance_id = component_instance_id
69
- self.trial_root_path = trial_root_path
70
- self.data_sources_map = data_sources_map
71
- self.config = component_specific_config if component_specific_config is not None else {}
72
-
73
- # 组件自身持久化配置或小型数据的路径
74
- # Path for the component to persist its own configuration or small data
75
- self.component_private_storage_path = self.trial_root_path / "visualizers_data" / self.component_instance_id
76
- self.component_private_storage_path.mkdir(parents=True, exist_ok=True)
77
-
78
- self._current_global_step: Optional[int] = None
79
- self._on_global_step_change_request: Optional[Callable[[int], None]] = None
80
- self._all_available_steps: Optional[List[int]] = None # 由主应用或数据加载器填充
81
-
82
- def _get_data_asset_info(self, logical_name: str = "default") -> Optional[Dict[str, Any]]:
83
- """
84
- 辅助方法,获取指定逻辑名称的数据资产信息。
85
- Helper method to get data asset info for a given logical name.
86
- 如果组件只处理一个主要数据源,可以使用 "default" 或在初始化时指定。
87
- If a component handles one primary data source, "default" or a specific name can be used.
88
- """
89
- return self.data_sources_map.get(logical_name)
90
-
91
- def _get_data_asset_path(self, logical_name: str = "default") -> Optional[Path]:
92
- """获取指定逻辑数据源的绝对路径。Gets the absolute path for a given logical data source."""
93
- asset_info = self._get_data_asset_info(logical_name)
94
- if asset_info and "path" in asset_info:
95
- # 路径在清单中是相对于trial_root_path的
96
- # Path in manifest is relative to trial_root_path
97
- return (self.trial_root_path / asset_info["path"]).resolve()
98
- return None
99
-
100
- def configure_global_step_interaction(self,
101
- current_step: Optional[int],
102
- all_available_steps: Optional[List[int]],
103
- on_step_change_request_callback: Optional[Callable[[int], None]]):
104
- """
105
- 由主应用调用,以配置与全局步骤相关的交互。
106
- Called by the main application to configure global step related interactions.
107
-
108
- Args:
109
- current_step (Optional[int]): 当前选中的全局步骤。
110
- The currently selected global step.
111
- all_available_steps (Optional[List[int]]): 此Trial中所有可用的全局步骤列表 (已排序)。
112
- A sorted list of all available global steps in this Trial.
113
- on_step_change_request_callback (Optional[Callable[[int], None]]):
114
- 当组件希望更改全局步骤时调用的回调函数。
115
- Callback function to be called when the component wishes to change the global step.
116
- """
117
- self._current_global_step = current_step
118
- self._all_available_steps = sorted(list(set(all_available_steps))) if all_available_steps else []
119
- self._on_global_step_change_request = on_step_change_request_callback
120
-
121
- def _request_global_step_change(self, new_step: int) -> None:
122
- """
123
- 组件内部调用此方法来请求更改全局共享的global_step。
124
- Component calls this internally to request a change to the shared global_step.
125
- """
126
- if self._on_global_step_change_request:
127
- if self._all_available_steps and new_step not in self._all_available_steps:
128
- # 如果请求的步骤无效,可以选择寻找最近的有效步骤或忽略
129
- # If requested step is invalid, can choose to find nearest valid step or ignore
130
- # For now, let's assume the interaction (e.g., chart click) provides a valid step from its data
131
- print(f"警告: 组件 {self.component_instance_id} 请求了一个无效的全局步骤 {new_step}。")
132
- # Potentially find closest:
133
- # if self._all_available_steps:
134
- # new_step = min(self._all_available_steps, key=lambda x: abs(x - new_step))
135
-
136
- self._on_global_step_change_request(new_step)
137
- else:
138
- st.warning(f"组件 {self.component_instance_id}: 尝试更改全局步骤,但未设置回调。")
139
-
140
- def _get_closest_available_step(self, target_step: Optional[int]) -> Optional[int]:
141
- """
142
- 如果目标步骤无效或数据在该步骤不可用,则查找最近的可用步骤。
143
- Finds the closest available step if the target step is invalid or data isn't available at that step.
144
- 组件的子类可以覆盖此逻辑以适应其特定的数据可用性。
145
- Subclasses can override this logic for their specific data availability.
146
- """
147
- if target_step is None:
148
- return None
149
- if not self._all_available_steps:
150
- return None
151
- if target_step in self._all_available_steps:
152
- return target_step
153
-
154
- # 寻找最接近的步骤 (简单的实现)
155
- # Find the closest step (simple implementation)
156
- try:
157
- closest = min(self._all_available_steps, key=lambda x: abs(x - target_step))
158
- return closest
159
- except ValueError: # _all_available_steps为空
160
- return None
161
-
162
-
163
- @abstractmethod
164
- def load_data(self) -> None:
165
- """
166
- 加载此组件渲染所需的数据。
167
- Load data required by this component for rendering.
168
- 实现者应使用 self._get_data_asset_path() 来获取数据文件路径。
169
- Implementers should use self._get_data_asset_path() to get data file paths.
170
- 数据加载后通常存储在实例变量中。
171
- Loaded data is typically stored in instance variables.
172
- """
173
- pass
174
-
175
- @abstractmethod
176
- def render(self) -> None:
177
- """
178
- 将组件渲染为Streamlit UI元素。
179
- Renders the component as Streamlit UI elements.
180
- 应使用 self._current_global_step (可能通过 self._get_closest_available_step 调整) 来显示对应步骤的数据。
181
- Should use self._current_global_step (possibly adjusted by self._get_closest_available_step)
182
- to display data for the corresponding step.
183
- 任何可以触发全局步骤更改的交互都应调用 self._request_global_step_change()。
184
- Any interaction that can trigger a global step change should call self._request_global_step_change().
185
- """
186
- pass
187
-
188
- @classmethod
189
- @abstractmethod
190
- def can_handle_data_types(cls, data_type_names: List[str]) -> bool:
191
- """
192
- 类方法:判断此组件类型是否能处理清单中声明的一个或多个数据类型。
193
- Class method: Determines if this component type can handle one or more data types
194
- declared in the manifest.
195
-
196
- Args:
197
- data_type_names (List[str]): 从数据资产清单中获取的数据类型名称列表。
198
- A list of data type names from a data asset manifest.
199
- (通常,主应用会为每个数据资产调用此方法,列表只包含一个元素)
200
- (Usually, the main app calls this for each data asset, so the list has one element)
201
-
202
- Returns:
203
- bool: True 如果此组件可以处理至少一种给定的数据类型。
204
- True if this component can handle at least one of the given data types.
205
- """
206
- pass
207
-
208
- @classmethod
209
- def get_display_name(cls) -> str:
210
- """
211
- 类方法:返回此组件类型的用户友好显示名称。
212
- Class method: Returns a user-friendly display name for this component type.
213
- """
214
- # 简单的实现:将类名从CamelCase转换为带空格的标题
215
- # Simple implementation: Convert CamelCase class name to space-separated title
216
- name = cls.__name__
217
- if name.endswith("Visualizer"):
218
- name = name[:-len("Visualizer")]
219
- s1 = VISUALIZER_REGISTRY.get(name, name) # Fallback to class name if not in registry (should not happen with decorator)
220
- # Add spaces before capital letters (simple version)
221
- import re
222
- return re.sub(r'(?<!^)(?=[A-Z])', ' ', s1)
223
-
224
-
225
- def save_component_config(self) -> None:
226
- """
227
- 将当前组件的特定配置 (self.config) 保存到其私有存储路径。
228
- Saves the current component-specific configuration (self.config) to its private storage path.
229
- """
230
- config_file = self.component_private_storage_path / "_component_config.toml"
231
- try:
232
- import tomli_w # 确保已安装 Ensure tomli_w is installed
233
- with open(config_file, "wb") as f:
234
- tomli_w.dump(self.config, f)
235
- # st.toast(f"组件 '{self.component_instance_id}' 配置已保存。")
236
- except Exception as e:
237
- st.error(f"保存组件 '{self.component_instance_id}' 配置失败: {e}")
238
-
239
- def load_component_config(self) -> None:
240
- """
241
- 从其私有存储路径加载组件的特定配置,并更新 self.config。
242
- Loads the component-specific configuration from its private storage path and updates self.config.
243
- """
244
- config_file = self.component_private_storage_path / "_component_config.toml"
245
- if config_file.exists():
246
- try:
247
- import tomli # 确保已安装 Ensure tomli is installed
248
- with open(config_file, "rb") as f:
249
- loaded_config = tomli.load(f)
250
- self.config.update(loaded_config) # 合并加载的配置 Merge loaded config
251
- # st.toast(f"组件 '{self.component_instance_id}' 配置已加载。")
252
- except Exception as e:
253
- st.error(f"加载组件 '{self.component_instance_id}' 配置失败: {e}")
254
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/visualizers/ide_app.py DELETED
@@ -1,303 +0,0 @@
1
- # src/debugger/ide_app.py
2
- import streamlit as st
3
- from pathlib import Path
4
- import tempfile
5
- import json
6
- import shutil
7
- from typing import Dict, Any, Optional, Type, List
8
- import tomli # 确保导入 tomli Ensure tomli is imported
9
-
10
- # 动态确定根目录,以便能够导入src下的模块
11
- # Dynamically determine root directory to import modules from src
12
- try:
13
- current_file_path = Path(__file__).resolve()
14
- # 假设 ide_app.py 在 src/debugger/ 目录下
15
- # Assumes ide_app.py is in src/debugger/
16
- # src_path = current_file_path.parent.parent
17
- # import sys
18
- # if str(src_path) not in sys.path:
19
- # sys.path.insert(0, str(src_path))
20
- except NameError:
21
- # project_root = Path(".") # 默认为当前工作目录
22
- pass # 在某些环境中 __file__ 可能未定义 In some environments __file__ might be undefined
23
-
24
- try:
25
- from visualizers.base_visualizer import (
26
- VISUALIZER_REGISTRY,
27
- get_visualizer_class,
28
- VisualizationComponent
29
- )
30
- # 确保所有组件都被导入和注册
31
- # Ensure all components are imported and registered
32
- import visualizers.scalar_dashboard_visualizer
33
-
34
- except ImportError as e:
35
- st.error(
36
- "错误:无法导入可视化组件模块。请确保此IDE应用位于正确的项目结构中,"
37
- f"并且 `visualizers` 目录及其内容可访问。\n详细信息: {e}"
38
- "Error: Could not import visualization component modules. Ensure this IDE app is in the correct "
39
- f"project structure and the `visualizers` directory and its contents are accessible.\nDetails: {e}"
40
- )
41
- st.stop()
42
-
43
- # --- 应用标题和配置 ---
44
- st.set_page_config(layout="wide", page_title="Flowillower - 组件IDE (Component IDE)")
45
- st.title("🔬 Flowillower - 可视化组件IDE (Visualization Component IDE)")
46
- st.markdown("在此环境中独立测试、调试和预览您的可视化组件。")
47
-
48
- # --- 会话状态初始化 ---
49
- if "selected_visualizer_type_name" not in st.session_state:
50
- st.session_state.selected_visualizer_type_name = None
51
- if "component_instance_id" not in st.session_state:
52
- st.session_state.component_instance_id = "test_instance_001"
53
- if "trial_root_path_str" not in st.session_state:
54
- st.session_state.trial_root_path_str = tempfile.mkdtemp(prefix="flowillower_ide_trial_")
55
- if "component_specific_config_str" not in st.session_state:
56
- st.session_state.component_specific_config_str = "{}"
57
- if "active_visualizer_instance" not in st.session_state:
58
- st.session_state.active_visualizer_instance = None
59
- if "generated_data_sources_map" not in st.session_state:
60
- st.session_state.generated_data_sources_map = None
61
- if "current_simulated_global_step" not in st.session_state:
62
- st.session_state.current_simulated_global_step = 0 # 默认值 Default value
63
- if "all_simulated_steps" not in st.session_state:
64
- st.session_state.all_simulated_steps = []
65
-
66
-
67
- def cleanup_temp_dir(path_str):
68
- try:
69
- if path_str and Path(path_str).exists() and "flowillower_ide_trial_" in path_str:
70
- shutil.rmtree(path_str)
71
- st.toast(f"临时目录已清理: {path_str}")
72
- except Exception as e:
73
- st.warning(f"清理临时目录失败 {path_str}: {e}")
74
-
75
- with st.sidebar:
76
- st.header("组件选择与配置")
77
-
78
- registered_type_names = list(VISUALIZER_REGISTRY.keys())
79
- if not registered_type_names:
80
- st.error("错误:没有已注册的可视化组件类型。")
81
- st.stop()
82
-
83
- if st.session_state.selected_visualizer_type_name is None and registered_type_names:
84
- st.session_state.selected_visualizer_type_name = registered_type_names[0]
85
-
86
- st.session_state.selected_visualizer_type_name = st.selectbox(
87
- "选择可视化组件类型",
88
- options=registered_type_names,
89
- index=registered_type_names.index(st.session_state.selected_visualizer_type_name)
90
- if st.session_state.selected_visualizer_type_name in registered_type_names else 0,
91
- help="选择您想要测试的可视化组件。"
92
- )
93
-
94
- SelectedVisualizerClass: Optional[Type[VisualizationComponent]] = get_visualizer_class(st.session_state.selected_visualizer_type_name)
95
-
96
- if SelectedVisualizerClass:
97
- st.caption(f"显示名称: `{SelectedVisualizerClass.get_display_name()}`")
98
- else:
99
- st.error(f"无法加载组件类 '{st.session_state.selected_visualizer_type_name}'。")
100
- st.stop()
101
-
102
- st.session_state.component_instance_id = st.text_input(
103
- "组件实例ID", value=st.session_state.component_instance_id
104
- )
105
- st.markdown(f"**临时Trial根路径:** `{st.session_state.trial_root_path_str}`")
106
-
107
- example_data_target_dir = Path(st.session_state.trial_root_path_str) / "example_assets_for_ide"
108
-
109
- if st.button("生成示例数据"):
110
- if SelectedVisualizerClass:
111
- try:
112
- if example_data_target_dir.exists():
113
- shutil.rmtree(example_data_target_dir)
114
- example_data_target_dir.mkdir(parents=True, exist_ok=True)
115
-
116
- st.session_state.generated_data_sources_map = SelectedVisualizerClass.generate_example_data(
117
- example_data_path=example_data_target_dir
118
- )
119
- st.success(f"'{SelectedVisualizerClass.get_display_name()}' 的示例数据已生成。")
120
-
121
- temp_all_steps = set()
122
- if st.session_state.generated_data_sources_map:
123
- for ds_name, ds_info in st.session_state.generated_data_sources_map.items():
124
- if "path" in ds_info:
125
- try:
126
- # 路径是相对于 trial_root_path 的
127
- # Path is relative to trial_root_path
128
- full_path = Path(st.session_state.trial_root_path_str) / ds_info["path"]
129
- if full_path.exists() and full_path.suffix == ".toml":
130
- with open(full_path, "rb") as f:
131
- d = tomli.load(f)
132
- if "metrics" in d and isinstance(d["metrics"], list):
133
- for point in d["metrics"]:
134
- if "global_step" in point:
135
- temp_all_steps.add(int(point["global_step"]))
136
- except Exception as e_load:
137
- st.warning(f"解析示例数据中的步骤时出错 ({ds_info['path']}): {e_load}")
138
-
139
- st.session_state.all_simulated_steps = sorted(list(temp_all_steps))
140
- if not st.session_state.all_simulated_steps: # 如果没有解析到步骤,至少放一个0
141
- st.session_state.all_simulated_steps = [0]
142
-
143
- # 更新 current_simulated_global_step 为有效值
144
- # Update current_simulated_global_step to a valid value
145
- if st.session_state.all_simulated_steps:
146
- st.session_state.current_simulated_global_step = st.session_state.all_simulated_steps[0]
147
- else: # 理论上不会到这里,因为上面保证了至少有[0] Theoretically won't reach here as [0] is guaranteed above
148
- st.session_state.current_simulated_global_step = 0
149
- st.rerun() # 重新运行以更新UI中的步骤滑块 Rerun to update step slider in UI
150
-
151
- except Exception as e:
152
- st.error(f"生成示例数据失败: {e}")
153
- st.exception(e)
154
- st.session_state.generated_data_sources_map = None
155
- else:
156
- st.warning("请先选择一个组件类型。")
157
-
158
- st.session_state.component_specific_config_str = st.text_area(
159
- "组件特定配置 (JSON)", value=st.session_state.component_specific_config_str, height=100
160
- )
161
-
162
- if st.button("🚀 实例化组件", type="primary"):
163
- if SelectedVisualizerClass and st.session_state.component_instance_id and st.session_state.generated_data_sources_map:
164
- try:
165
- specific_config = json.loads(st.session_state.component_specific_config_str)
166
- st.session_state.active_visualizer_instance = SelectedVisualizerClass(
167
- component_instance_id=st.session_state.component_instance_id,
168
- trial_root_path=Path(st.session_state.trial_root_path_str),
169
- data_sources_map=st.session_state.generated_data_sources_map,
170
- component_specific_config=specific_config
171
- )
172
- st.success(f"组件 '{st.session_state.component_instance_id}' 已实例化。")
173
- active_viz_instance_for_load = st.session_state.active_visualizer_instance
174
- active_viz_instance_for_load.load_data()
175
-
176
- # 从组件加载数据后获取实际的all_available_steps
177
- # Get actual all_available_steps after component loads data
178
- if hasattr(active_viz_instance_for_load, '_all_available_steps') and active_viz_instance_for_load._all_available_steps:
179
- st.session_state.all_simulated_steps = active_viz_instance_for_load._all_available_steps
180
- if st.session_state.current_simulated_global_step not in st.session_state.all_simulated_steps:
181
- st.session_state.current_simulated_global_step = st.session_state.all_simulated_steps[0]
182
-
183
- active_viz_instance_for_load.configure_global_step_interaction(
184
- current_step=st.session_state.current_simulated_global_step,
185
- all_available_steps=st.session_state.all_simulated_steps,
186
- on_step_change_request_callback=lambda step: st.session_state.update({"current_simulated_global_step": step})
187
- )
188
- st.rerun() # 确保UI更新 Ensure UI updates
189
-
190
- except Exception as e:
191
- st.error(f"实例化组件失败: {e}")
192
- st.exception(e)
193
- st.session_state.active_visualizer_instance = None
194
- else:
195
- st.warning("请先选择组件类型,输入实例ID,并生成示例数据。")
196
-
197
- st.header("组件预览与交互")
198
- active_viz_instance: Optional[VisualizationComponent] = st.session_state.active_visualizer_instance
199
-
200
- if active_viz_instance:
201
- st.markdown(f"**当前活动组件:** `{active_viz_instance.component_instance_id}` "
202
- f"(类型: `{SelectedVisualizerClass.get_display_name() if SelectedVisualizerClass else 'N/A'}`)")
203
- st.markdown(f"**Trial根路径:** `{active_viz_instance.trial_root_path}`")
204
-
205
- st.markdown("---")
206
- st.subheader("全局步骤模拟")
207
- col_step1, col_step2 = st.columns([3,1])
208
-
209
- # --- 全局步骤模拟 (FIXED RangeError) ---
210
- # --- Global Step Simulation (FIXED RangeError) ---
211
- current_step_for_ui = st.session_state.current_simulated_global_step
212
- all_steps_for_ui = st.session_state.all_simulated_steps
213
-
214
- # 确保 current_step_for_ui 在 all_steps_for_ui 中 (如果 all_steps_for_ui 非空)
215
- # Ensure current_step_for_ui is in all_steps_for_ui (if all_steps_for_ui is not empty)
216
- if all_steps_for_ui and current_step_for_ui not in all_steps_for_ui:
217
- # 尝试寻找最近的,或者直接用第一个/最后一个
218
- # Try to find the closest, or just use the first/last
219
- current_step_for_ui = min(all_steps_for_ui, key=lambda x:abs(x-current_step_for_ui)) if all_steps_for_ui else 0
220
- # st.session_state.current_simulated_global_step = current_step_for_ui # 避免在渲染中直接修改会话状态 Avoid direct session state modification in render path
221
-
222
- with col_step1:
223
- if all_steps_for_ui:
224
- if len(all_steps_for_ui) == 1:
225
- st.markdown(f"当前模拟全局步骤: **{all_steps_for_ui[0]}** (只有一步可用)")
226
- # 如果只有一步,确保会话状态也正确
227
- # If only one step, ensure session state is also correct
228
- if st.session_state.current_simulated_global_step != all_steps_for_ui[0]:
229
- st.session_state.current_simulated_global_step = all_steps_for_ui[0]
230
- # st.rerun() # 可能会导致循环,让 configure_global_step_interaction 处理 Might cause loop, let configure_global_step_interaction handle
231
- new_sim_step = all_steps_for_ui[0] # 保持一致 Keep consistent
232
- else: # 多于一个步骤 More than one step
233
- new_sim_step = st.select_slider(
234
- "当前模拟全局步骤",
235
- options=all_steps_for_ui,
236
- value=current_step_for_ui, # 使用已验证的值 Use validated value
237
- key="sim_step_slider_corrected"
238
- )
239
- else: # 没有从数据中解析到步骤 No steps parsed from data
240
- new_sim_step = st.number_input(
241
- "当前模拟全局步骤 (无可用步骤)",
242
- value=current_step_for_ui,
243
- key="sim_step_input_empty_corrected"
244
- )
245
-
246
- # 如果用户通过UI更改了步骤,则更新会话状态
247
- # If user changed step via UI, update session state
248
- if new_sim_step != st.session_state.current_simulated_global_step:
249
- st.session_state.current_simulated_global_step = new_sim_step
250
- # st.rerun() # select_slider/number_input 通常会在值变化时自动 rerun
251
-
252
- # 每次渲染都更新组件的步骤信息
253
- # Update component's step info on every render
254
- active_viz_instance.configure_global_step_interaction(
255
- current_step=st.session_state.current_simulated_global_step,
256
- all_available_steps=all_steps_for_ui, # 使用从数据加载的步骤 Use steps loaded from data
257
- on_step_change_request_callback=lambda step: st.session_state.update({"current_simulated_global_step": step}) # 更新会话状态,Streamlit会自动rerun
258
- )
259
-
260
- with col_step2:
261
- if st.button("🔄 重新加载数据"):
262
- try:
263
- active_viz_instance.load_data()
264
- st.toast("组件数据已重新加载。")
265
- # 重新加载数据后,可能需要更新 all_simulated_steps
266
- # After reloading data, all_simulated_steps might need an update
267
- if hasattr(active_viz_instance, '_all_available_steps') and active_viz_instance._all_available_steps:
268
- st.session_state.all_simulated_steps = active_viz_instance._all_available_steps
269
- # 确保当前步骤仍然有效
270
- # Ensure current step is still valid
271
- if st.session_state.current_simulated_global_step not in st.session_state.all_simulated_steps:
272
- st.session_state.current_simulated_global_step = st.session_state.all_simulated_steps[0] if st.session_state.all_simulated_steps else 0
273
-
274
- active_viz_instance.configure_global_step_interaction( # 再次配置以防步骤列表变化 Reconfigure in case step list changed
275
- current_step=st.session_state.current_simulated_global_step,
276
- all_available_steps=st.session_state.all_simulated_steps,
277
- on_step_change_request_callback=lambda step: st.session_state.update({"current_simulated_global_step": step})
278
- )
279
- st.rerun() # 强制刷新UI Force UI refresh
280
- except Exception as e:
281
- st.error(f"重新加载数据失败: {e}")
282
-
283
- st.markdown("---")
284
- st.subheader("渲染输出")
285
- try:
286
- with st.container(border=True):
287
- active_viz_instance.render()
288
- except Exception as e:
289
- st.error(f"渲染组件 '{active_viz_instance.component_instance_id}' 时出错: {e}")
290
- st.exception(e)
291
- else:
292
- st.info("请在侧边栏中选择一个组件类型,生成示例数据,然后点击“实例化组件”以开始调试。")
293
-
294
- st.sidebar.markdown("---")
295
- st.sidebar.caption(f"IDE 会话临时路径: {st.session_state.trial_root_path_str}")
296
- if st.sidebar.button("清理当前会话的临时Trial目录"):
297
- cleanup_temp_dir(st.session_state.trial_root_path_str)
298
- st.session_state.trial_root_path_str = tempfile.mkdtemp(prefix="flowillower_ide_trial_")
299
- st.session_state.active_visualizer_instance = None
300
- st.session_state.generated_data_sources_map = None
301
- st.session_state.all_simulated_steps = []
302
- st.session_state.current_simulated_global_step = 0
303
- st.rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/visualizers/scalar_dashboard_visualizer.py DELETED
@@ -1,333 +0,0 @@
1
- # src/visualizers/scalar_dashboard_visualizer.py
2
- import streamlit as st
3
- import pandas as pd
4
- import plotly.graph_objects as go
5
- import plotly.express as px
6
- import tomli
7
- import tomli_w
8
- from pathlib import Path
9
- from typing import Dict, Any, Optional, List, Tuple
10
-
11
- from .base_visualizer import VisualizationComponent, register_visualizer
12
-
13
- @register_visualizer(name="scalar_metrics_dashboard")
14
- class ScalarMetricsDashboardVisualizer(VisualizationComponent):
15
- """
16
- 一个用于显示多个标量指标(每个指标可能包含多个track)的仪表盘组件。
17
- 它会为每个指标在选定的global_step显示st.metric摘要,并绘制一个Plotly折线图。
18
- A dashboard component to display multiple scalar metrics, each potentially with multiple tracks.
19
- It shows st.metric summaries for the selected global_step and a Plotly line chart for each metric.
20
- """
21
-
22
- _raw_data: Optional[pd.DataFrame] = None
23
- _processed_metrics_data: Dict[str, pd.DataFrame] = None
24
-
25
- DEFAULT_CHARTS_PER_ROW = 2
26
-
27
- def __init__(self,
28
- component_instance_id: str,
29
- trial_root_path: Path,
30
- data_sources_map: Dict[str, Dict[str, Any]],
31
- component_specific_config: Dict[str, Any] = None):
32
- super().__init__(component_instance_id, trial_root_path, data_sources_map, component_specific_config)
33
- self._raw_data = None
34
- self._processed_metrics_data = {}
35
- self.load_component_config()
36
-
37
- self.charts_per_row = self.config.get("charts_per_row", self.DEFAULT_CHARTS_PER_ROW)
38
- self.chart_height = self.config.get("chart_height", 400)
39
-
40
- @classmethod
41
- def can_handle_data_types(cls, data_type_names: List[str]) -> bool:
42
- return "multi_metric_multi_track_scalars" in data_type_names
43
-
44
- @classmethod
45
- def generate_example_data(cls, example_data_path: Path, # This path is like .../trial_root/example_assets_for_ide
46
- data_sources_config: Optional[Dict[str, Dict[str, Any]]] = None
47
- ) -> Dict[str, Dict[str, Any]]:
48
- example_data_path.mkdir(parents=True, exist_ok=True)
49
- metrics_file_name = "example_scalar_metrics.toml"
50
- # 文件实际写入位置: example_data_path / metrics_file_name
51
- # Actual file write location: example_data_path / metrics_file_name
52
- metrics_file_full_path = example_data_path / metrics_file_name
53
-
54
- data_points = []
55
- for step in range(20):
56
- data_points.append({"global_step": step, "track": "train", "loss": 1.0 / (step + 1) + 0.1, "accuracy": 0.6 + step * 0.015})
57
- data_points.append({"global_step": step, "track": "validation", "loss": 1.0 / (step + 1) + 0.2, "accuracy": 0.55 + step * 0.01})
58
- if step % 5 == 0:
59
- data_points.append({"global_step": step, "track": "system", "learning_rate": 0.001 * (0.9**step)})
60
-
61
- try:
62
- with open(metrics_file_full_path, "wb") as f:
63
- tomli_w.dump({"metrics": data_points}, f)
64
- except Exception as e:
65
- st.error(f"生成示例数据失败: {e}")
66
- raise
67
-
68
- # *** 修正点 (FIXED POINT) ***
69
- # 返回的路径应该是相对于 trial_root_path 的。
70
- # The returned path should be relative to trial_root_path.
71
- # example_data_path 的父目录是 trial_root_path。
72
- # The parent of example_data_path is trial_root_path in the IDE's context.
73
- # 所以,相对路径是 example_data_path 的名称(即 "example_assets_for_ide")加上文件名。
74
- # So, the relative path is the name of example_data_path (i.e., "example_assets_for_ide") plus the filename.
75
- # path_relative_to_trial_root = example_data_path.name / metrics_file_name
76
- # 更稳健的方式是,如果 example_data_path 是绝对路径,而我们需要相对于 trial_root_path 的路径,
77
- # 并且我们知道 trial_root_path 是 example_data_path 的父目录(或更早的祖先)。
78
- # A more robust way, if example_data_path is absolute, and we need path relative to trial_root_path,
79
- # and we know trial_root_path is a parent (or earlier ancestor) of example_data_path.
80
- # 在IDE的上下文中,example_data_path = trial_root_path / "example_assets_for_ide"
81
- # In the IDE's context, example_data_path = trial_root_path / "example_assets_for_ide"
82
- # 所以,相对于 trial_root_path 的路径就是 "example_assets_for_ide" / metrics_file_name
83
- # So, path relative to trial_root_path is "example_assets_for_ide" / metrics_file_name
84
-
85
- path_for_manifest = Path(example_data_path.name) / metrics_file_name
86
-
87
- return {
88
- "main_metrics_source": {
89
- "asset_id": "example_metrics_dashboard_data",
90
- "data_type": "multi_metric_multi_track_scalars",
91
- "path": str(path_for_manifest), # 使用修正后的相对路径 Use the corrected relative path
92
- "display_name": "示例综合指标数据 (Example Comprehensive Metrics)"
93
- }
94
- }
95
-
96
- def load_data(self) -> None:
97
- data_asset_path = self._get_data_asset_path("main_metrics_source")
98
- if data_asset_path is None or not data_asset_path.exists():
99
- st.warning(f"组件 {self.component_instance_id}: 未找到数据源 'main_metrics_source' 或路径无效: {data_asset_path}")
100
- self._raw_data = pd.DataFrame()
101
- self._processed_metrics_data = {}
102
- return
103
-
104
- try:
105
- with open(data_asset_path, "rb") as f:
106
- data = tomli.load(f)
107
-
108
- metrics_list = data.get("metrics", [])
109
- if not metrics_list:
110
- self._raw_data = pd.DataFrame()
111
- else:
112
- self._raw_data = pd.DataFrame(metrics_list)
113
-
114
- self._process_raw_data()
115
-
116
- except Exception as e:
117
- st.error(f"组件 {self.component_instance_id}: 加载数据 '{data_asset_path}' 失败: {e}")
118
- self._raw_data = pd.DataFrame()
119
- self._processed_metrics_data = {}
120
-
121
- def _process_raw_data(self) -> None:
122
- self._processed_metrics_data = {}
123
- if self._raw_data is None or self._raw_data.empty:
124
- return
125
-
126
- if "global_step" not in self._raw_data.columns:
127
- st.warning(f"组件 {self.component_instance_id}: 数据缺少 'global_step' 列。")
128
- return
129
-
130
- potential_metric_cols = [
131
- col for col in self._raw_data.columns if col not in ["global_step", "track"]
132
- ]
133
-
134
- for metric_col_name in potential_metric_cols:
135
- metric_df_cols = ["global_step", metric_col_name]
136
- if "track" in self._raw_data.columns:
137
- metric_df_cols.append("track")
138
-
139
- # 确保所有需要的列都存在于self._raw_data中
140
- # Ensure all needed columns exist in self._raw_data
141
- if not all(col in self._raw_data.columns for col in metric_df_cols):
142
- # st.warning(f"组件 {self.component_instance_id}: 指标 '{metric_col_name}' 的原始数据缺少某些列: {metric_df_cols}")
143
- continue
144
-
145
- metric_df = self._raw_data[metric_df_cols].copy()
146
-
147
- if "track" not in metric_df.columns: # 如果原始数据就没有track列
148
- metric_df["track"] = "default"
149
-
150
- metric_df.rename(columns={metric_col_name: "value"}, inplace=True)
151
- metric_df["value"] = pd.to_numeric(metric_df["value"], errors='coerce')
152
- metric_df["global_step"] = pd.to_numeric(metric_df["global_step"], errors='coerce')
153
- metric_df.dropna(subset=["value", "global_step"], inplace=True)
154
-
155
- if not metric_df.empty:
156
- metric_df = metric_df.sort_values(by=["track", "global_step"]).reset_index(drop=True)
157
- self._processed_metrics_data[metric_col_name] = metric_df
158
-
159
- if self._all_available_steps is None and not self._raw_data.empty and "global_step" in self._raw_data.columns:
160
- valid_steps = pd.to_numeric(self._raw_data["global_step"], errors='coerce').dropna().astype(int).unique()
161
- self._all_available_steps = sorted(list(valid_steps))
162
-
163
-
164
- def _render_metric_summary(self, metric_name: str, metric_df: pd.DataFrame, target_step: Optional[int]):
165
- if target_step is None and self._all_available_steps: # 如果target_step是None,默认用最后一个step
166
- target_step = self._all_available_steps[-1]
167
- elif target_step is None or not self._all_available_steps:
168
- st.metric(label=f"{metric_name}", value="无可用步骤", delta=None)
169
- return
170
-
171
- actual_display_step = self._get_closest_available_step(target_step)
172
- if actual_display_step is None:
173
- st.metric(label=f"{metric_name}", value="无数据", delta=None)
174
- return
175
-
176
- all_tracks = sorted(list(metric_df["track"].unique()))
177
- num_tracks = len(all_tracks)
178
- if num_tracks == 0: return
179
-
180
- cols = st.columns(num_tracks) if num_tracks > 1 else [st.container()]
181
-
182
- for idx, track_name in enumerate(all_tracks):
183
- with cols[idx if num_tracks > 1 else 0]:
184
- track_data = metric_df[metric_df["track"] == track_name]
185
- if track_data.empty:
186
- st.metric(label=f"{metric_name} ({track_name})", value="无数据", delta=None)
187
- continue
188
-
189
- current_value = None
190
- delta_value = None
191
- step_for_current_value = actual_display_step
192
-
193
- current_step_data = track_data[track_data["global_step"] == step_for_current_value]
194
- if current_step_data.empty:
195
- prev_steps_for_track = track_data[track_data["global_step"] <= step_for_current_value]["global_step"]
196
- if not prev_steps_for_track.empty:
197
- step_for_current_value = prev_steps_for_track.max()
198
- current_step_data = track_data[track_data["global_step"] == step_for_current_value]
199
-
200
- if not current_step_data.empty:
201
- current_value = current_step_data["value"].iloc[0]
202
- prev_steps_for_delta = track_data[track_data["global_step"] < step_for_current_value]["global_step"]
203
- if not prev_steps_for_delta.empty:
204
- step_for_prev_value = prev_steps_for_delta.max()
205
- prev_value_data = track_data[track_data["global_step"] == step_for_prev_value]
206
- if not prev_value_data.empty:
207
- prev_value = prev_value_data["value"].iloc[0]
208
- delta_value = current_value - prev_value
209
-
210
- metric_label = f"{metric_name} ({track_name})"
211
- if step_for_current_value != target_step and current_value is not None: # 仅当找到的值的步骤与目标步骤不同时才显示
212
- metric_label += f" @S{int(step_for_current_value)}"
213
-
214
- st.metric(
215
- label=metric_label,
216
- value=f"{current_value:.4f}" if current_value is not None else "无数据",
217
- delta=f"{delta_value:.4f}" if delta_value is not None and current_value is not None else None,
218
- )
219
-
220
- def _render_plotly_chart(self, metric_name: str, metric_df: pd.DataFrame, chart_key: str):
221
- fig = go.Figure()
222
- all_tracks = sorted(list(metric_df["track"].unique()))
223
- colors = px.colors.qualitative.Plotly
224
-
225
- for i, track_name in enumerate(all_tracks):
226
- track_data = metric_df[metric_df["track"] == track_name]
227
- fig.add_trace(go.Scatter(
228
- x=track_data["global_step"],
229
- y=track_data["value"],
230
- mode="lines+markers",
231
- name=track_name,
232
- line=dict(color=colors[i % len(colors)]),
233
- marker=dict(size=6, color=colors[i % len(colors)], line=dict(width=1, color="white")),
234
- customdata=track_data[["global_step", "value", "track"]],
235
- hovertemplate="<b>%{customdata[2]}</b><br>Step: %{customdata[0]}<br>Value: %{customdata[1]:.4f}<extra></extra>"
236
- ))
237
-
238
- current_step_to_highlight = self._get_closest_available_step(self._current_global_step)
239
- if current_step_to_highlight is not None:
240
- fig.add_vline(
241
- x=current_step_to_highlight,
242
- line_width=1.5, line_dash="solid", line_color="firebrick", opacity=0.7
243
- )
244
-
245
- fig.update_layout(
246
- xaxis_title="Global Step",
247
- yaxis_title=metric_name,
248
- height=self.chart_height,
249
- margin=dict(l=10, r=10, t=30, b=10),
250
- showlegend=len(all_tracks) > 1,
251
- hovermode="closest",
252
- )
253
-
254
- event_data = st.plotly_chart(
255
- fig,
256
- use_container_width=True,
257
- key=chart_key,
258
- on_select="rerun"
259
- )
260
-
261
- current_selection = st.session_state.get(chart_key, {}).get("selection")
262
- if current_selection and current_selection.get("points"):
263
- clicked_point = current_selection["points"][0]
264
- if "customdata" in clicked_point and isinstance(clicked_point["customdata"], list):
265
- clicked_global_step = int(clicked_point["customdata"][0])
266
- last_clicked_step_key = f"{chart_key}_last_clicked_step"
267
- if st.session_state.get(last_clicked_step_key) != clicked_global_step:
268
- st.session_state[last_clicked_step_key] = clicked_global_step
269
- self._request_global_step_change(clicked_global_step)
270
- elif current_selection and not current_selection.get("points"):
271
- st.session_state[f"{chart_key}_last_clicked_step"] = None
272
-
273
-
274
- def render_settings_ui(self):
275
- st.markdown("##### 组件设置 (Component Settings)")
276
- new_charts_per_row = st.slider(
277
- "每行图表数 (Charts per row)", 1, 4, self.charts_per_row,
278
- key=f"{self.component_instance_id}_charts_per_row"
279
- )
280
- if new_charts_per_row != self.charts_per_row:
281
- self.charts_per_row = new_charts_per_row
282
- self.config["charts_per_row"] = new_charts_per_row
283
- self.save_component_config()
284
- st.rerun()
285
-
286
- new_chart_height = st.number_input(
287
- "图表高度 (Chart Height)", min_value=200, max_value=1000, step=50,
288
- value=self.chart_height,
289
- key=f"{self.component_instance_id}_chart_height"
290
- )
291
- if new_chart_height != self.chart_height:
292
- self.chart_height = new_chart_height
293
- self.config["chart_height"] = new_chart_height
294
- self.save_component_config()
295
- st.rerun()
296
-
297
-
298
- def render(self) -> None:
299
- if self._processed_metrics_data is None or not self._processed_metrics_data:
300
- if self._raw_data is None: # 尝试加载一次
301
- self.load_data()
302
- # 再次检查
303
- if self._processed_metrics_data is None or not self._processed_metrics_data:
304
- st.info(f"组件 {self.component_instance_id}: 没有处理好的指标数据可供显示。请先加载数据或生成示例数据。")
305
- st.caption("如果已生成示例数据但仍看到此消息,请检查数据路径和格式是否正确。If example data was generated but you still see this, check data paths and format.")
306
- return
307
-
308
- with st.expander("图表显示设置 (Chart Display Settings)", expanded=False):
309
- self.render_settings_ui()
310
-
311
- metric_names_to_display = sorted(list(self._processed_metrics_data.keys()))
312
- if not metric_names_to_display:
313
- st.caption("没有可显示的指标。No metrics to display.")
314
- return
315
-
316
- num_metrics = len(metric_names_to_display)
317
- cols_per_row = self.charts_per_row
318
-
319
- for i in range(0, num_metrics, cols_per_row):
320
- metric_chunk = metric_names_to_display[i : i + cols_per_row]
321
- actual_cols_for_this_row = len(metric_chunk)
322
- chart_cols = st.columns(actual_cols_for_this_row)
323
-
324
- for j, metric_name in enumerate(metric_chunk):
325
- with chart_cols[j]:
326
- metric_df = self._processed_metrics_data[metric_name]
327
-
328
- with st.container(border=True, height=self.chart_height + 200): # 增加容器高度以容纳metric和图表
329
- st.subheader(metric_name)
330
- self._render_metric_summary(metric_name, metric_df, self._current_global_step)
331
- st.markdown("---")
332
- chart_key = f"plotly_{self.component_instance_id}_{metric_name}"
333
- self._render_plotly_chart(metric_name, metric_df, chart_key)