Create README.md
Browse files
README.md
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- jax
|
| 5 |
+
- rl
|
| 6 |
+
- jumanji
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
# CVRP-V1
|
| 10 |
+
This model is trained on the Jumanji CVRP environment
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
**Developed by:** InstaDeep
|
| 14 |
+
|
| 15 |
+
### Model Sources
|
| 16 |
+
|
| 17 |
+
<!-- Provide the basic links for the model. -->
|
| 18 |
+
|
| 19 |
+
- **Repository:** [Jumanji](https://github.com/instadeepai/jumanji)
|
| 20 |
+
- **Paper:** TBD
|
| 21 |
+
|
| 22 |
+
### How to use
|
| 23 |
+
|
| 24 |
+
[Notebook](#)
|
| 25 |
+
|
| 26 |
+
Go to the jumanji repo for the primary model and requirements. Clone the repo and navigate to the root directory.
|
| 27 |
+
|
| 28 |
+
```
|
| 29 |
+
pip install -e .
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
Below is an example script for loading and running the Jumanji model
|
| 33 |
+
|
| 34 |
+
```python
|
| 35 |
+
import pickle
|
| 36 |
+
import joblib
|
| 37 |
+
|
| 38 |
+
import jax
|
| 39 |
+
from hydra import compose, initialize
|
| 40 |
+
from huggingface_hub import hf_hub_download
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
from jumanji.training.setup_train import setup_agent, setup_env
|
| 44 |
+
from jumanji.training.utils import first_from_device
|
| 45 |
+
|
| 46 |
+
# initialise the config
|
| 47 |
+
with initialize(version_base=None, config_path="jumanji/training/configs"):
|
| 48 |
+
cfg = compose(config_name="config.yaml", overrides=["env=cvrp", "agent=a2c"])
|
| 49 |
+
|
| 50 |
+
# get model state from HF
|
| 51 |
+
REPO_ID = "InstaDeepAI/jumanji-cvrp-v1-a2c-benchmark"
|
| 52 |
+
FILENAME = "CVRP-v1_training_state"
|
| 53 |
+
|
| 54 |
+
model_weights = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
|
| 55 |
+
|
| 56 |
+
with open(model_weights,"rb") as f:
|
| 57 |
+
training_state = pickle.load(f)
|
| 58 |
+
|
| 59 |
+
params = first_from_device(training_state.params_state.params)
|
| 60 |
+
env = setup_env(cfg).unwrapped
|
| 61 |
+
agent = setup_agent(cfg, env)
|
| 62 |
+
policy = jax.jit(agent.make_policy(params.actor, stochastic = False))
|
| 63 |
+
|
| 64 |
+
# rollout a few episodes
|
| 65 |
+
NUM_EPISODES = 10
|
| 66 |
+
|
| 67 |
+
states = []
|
| 68 |
+
key = jax.random.PRNGKey(cfg.seed)
|
| 69 |
+
for episode in range(NUM_EPISODES):
|
| 70 |
+
key, reset_key = jax.random.split(key)
|
| 71 |
+
state, timestep = jax.jit(env.reset)(reset_key)
|
| 72 |
+
while not timestep.last():
|
| 73 |
+
key, action_key = jax.random.split(key)
|
| 74 |
+
observation = jax.tree_util.tree_map(lambda x: x[None], timestep.observation)
|
| 75 |
+
action, _ = policy(observation, action_key)
|
| 76 |
+
state, timestep = jax.jit(env.step)(state, action.squeeze(axis=0))
|
| 77 |
+
states.append(state)
|
| 78 |
+
# Freeze the terminal frame to pause the GIF.
|
| 79 |
+
for _ in range(10):
|
| 80 |
+
states.append(state)
|
| 81 |
+
|
| 82 |
+
# animate a GIF
|
| 83 |
+
env.animate(states, interval=150).save("./binpack.gif")
|
| 84 |
+
|
| 85 |
+
# save PNG
|
| 86 |
+
import matplotlib.pyplot as plt
|
| 87 |
+
%matplotlib inline
|
| 88 |
+
env.render(states[117])
|
| 89 |
+
plt.savefig("connector.png", dpi=300)
|
| 90 |
+
|
| 91 |
+
```
|