put more context
Browse files- app.py +15 -9
- requirements.txt +2 -1
- utils.py +11 -1
app.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
-
import os
|
| 2 |
import wandb
|
| 3 |
import streamlit as st
|
| 4 |
import streamlit.components.v1 as components
|
| 5 |
|
| 6 |
-
from utils import train
|
| 7 |
|
| 8 |
project = "st"
|
| 9 |
entity = "capecape"
|
|
@@ -13,23 +13,29 @@ HEIGHT = 720
|
|
| 13 |
def get_project(api, name, entity=None):
|
| 14 |
return api.project(name, entity=entity).to_html(height=HEIGHT)
|
| 15 |
|
| 16 |
-
st.title("
|
|
|
|
|
|
|
| 17 |
|
| 18 |
# Sidebar
|
| 19 |
sb = st.sidebar
|
| 20 |
sb.title("Train your model")
|
| 21 |
# wandb_token = sb.text_input("paste your wandb Api key if you want: https://wandb.ai/authorize", type="password")
|
| 22 |
|
|
|
|
| 23 |
|
| 24 |
# wandb.login(key=wandb_token)
|
| 25 |
wandb.login(anonymous="must")
|
| 26 |
api = wandb.Api()
|
| 27 |
|
|
|
|
|
|
|
|
|
|
| 28 |
# render wandb dashboard
|
| 29 |
components.html(get_project(api, project, entity), height=HEIGHT)
|
| 30 |
|
| 31 |
# run params
|
| 32 |
-
runs =
|
| 33 |
epochs = sb.number_input('Number of epochs:', min_value=1, max_value=1000, value=100)
|
| 34 |
|
| 35 |
|
|
@@ -39,10 +45,10 @@ We will execute a simple training loop
|
|
| 39 |
```python
|
| 40 |
wandb.init(project="st", ...)
|
| 41 |
for i in range(epochs):
|
| 42 |
-
acc = 1 - 2 ** -i - random()
|
| 43 |
-
loss = 2 ** -i + random()
|
| 44 |
-
wandb.log({"acc": acc,
|
| 45 |
-
|
| 46 |
```
|
| 47 |
"""
|
| 48 |
|
|
@@ -54,4 +60,4 @@ if sb.button("Run Example"):
|
|
| 54 |
print("Running training")
|
| 55 |
for i in range(runs):
|
| 56 |
my_bar = sb.progress(0)
|
| 57 |
-
train(project=project, entity=entity, epochs=epochs, bar=my_bar)
|
|
|
|
| 1 |
+
import os, random
|
| 2 |
import wandb
|
| 3 |
import streamlit as st
|
| 4 |
import streamlit.components.v1 as components
|
| 5 |
|
| 6 |
+
from utils import train, WORDS
|
| 7 |
|
| 8 |
project = "st"
|
| 9 |
entity = "capecape"
|
|
|
|
| 13 |
def get_project(api, name, entity=None):
|
| 14 |
return api.project(name, entity=entity).to_html(height=HEIGHT)
|
| 15 |
|
| 16 |
+
st.title("The wandb Dashboard π")
|
| 17 |
+
|
| 18 |
+
run_name = "-".join(random.choices(WORDS, k=2)) + f"-{random.randint(0,100)}"
|
| 19 |
|
| 20 |
# Sidebar
|
| 21 |
sb = st.sidebar
|
| 22 |
sb.title("Train your model")
|
| 23 |
# wandb_token = sb.text_input("paste your wandb Api key if you want: https://wandb.ai/authorize", type="password")
|
| 24 |
|
| 25 |
+
run_name = sb.text_input("Run name", run_name, disabled=True)
|
| 26 |
|
| 27 |
# wandb.login(key=wandb_token)
|
| 28 |
wandb.login(anonymous="must")
|
| 29 |
api = wandb.Api()
|
| 30 |
|
| 31 |
+
st.success(f"You should see a new run named **{run_name}**, it\'ll have a green circle while it\'s still active")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
# render wandb dashboard
|
| 35 |
components.html(get_project(api, project, entity), height=HEIGHT)
|
| 36 |
|
| 37 |
# run params
|
| 38 |
+
runs = 1
|
| 39 |
epochs = sb.number_input('Number of epochs:', min_value=1, max_value=1000, value=100)
|
| 40 |
|
| 41 |
|
|
|
|
| 45 |
```python
|
| 46 |
wandb.init(project="st", ...)
|
| 47 |
for i in range(epochs):
|
| 48 |
+
acc = 1 - 2 ** -i - random()
|
| 49 |
+
loss = 2 ** -i + random()
|
| 50 |
+
wandb.log({"acc": acc,
|
| 51 |
+
"loss": loss})
|
| 52 |
```
|
| 53 |
"""
|
| 54 |
|
|
|
|
| 60 |
print("Running training")
|
| 61 |
for i in range(runs):
|
| 62 |
my_bar = sb.progress(0)
|
| 63 |
+
train(name=run_name, project=project, entity=entity, epochs=epochs, bar=my_bar)
|
requirements.txt
CHANGED
|
@@ -1,2 +1,3 @@
|
|
| 1 |
wandb
|
| 2 |
-
streamlit
|
|
|
|
|
|
| 1 |
wandb
|
| 2 |
+
streamlit
|
| 3 |
+
requests
|
utils.py
CHANGED
|
@@ -1,11 +1,21 @@
|
|
| 1 |
import random, time
|
|
|
|
| 2 |
|
| 3 |
import wandb
|
| 4 |
|
| 5 |
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
run = wandb.init(
|
| 8 |
# Set the project where this run will be logged
|
|
|
|
| 9 |
project=project,
|
| 10 |
entity=entity,
|
| 11 |
# Track hyperparameters and run metadata
|
|
|
|
| 1 |
import random, time
|
| 2 |
+
import requests
|
| 3 |
|
| 4 |
import wandb
|
| 5 |
|
| 6 |
|
| 7 |
+
|
| 8 |
+
word_site = "https://www.mit.edu/~ecprice/wordlist.10000"
|
| 9 |
+
|
| 10 |
+
response = requests.get(word_site)
|
| 11 |
+
WORDS = [w.decode("UTF-8") for w in response.content.splitlines()]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def train(name, project="st", entity=None, epochs=10, bar=None):
|
| 16 |
run = wandb.init(
|
| 17 |
# Set the project where this run will be logged
|
| 18 |
+
name=name,
|
| 19 |
project=project,
|
| 20 |
entity=entity,
|
| 21 |
# Track hyperparameters and run metadata
|