AmnaHassan committed on
Commit ad43496 · verified · 1 Parent(s): 0f85d48

Update app.py

Files changed (1)
  1. app.py +227 -102
app.py CHANGED
@@ -1,23 +1,41 @@
 
 import numpy as np
 import matplotlib.pyplot as plt
-from mpl_toolkits.mplot3d import Axes3D
-import gradio as gr
 from io import BytesIO
 import base64
-import random

-# ---------- ENVIRONMENT SETUP ----------
 GRID_SIZE = 8
 ACTIONS = ['up', 'down', 'left', 'right']

 class CarEnvironment:
-    def __init__(self):
-        self.reset()
-
-    def reset(self):
         self.car = (0, 0)
         self.goal = (GRID_SIZE - 1, GRID_SIZE - 1)
-        self.obstacles = [(random.randint(1, GRID_SIZE-2), random.randint(1, GRID_SIZE-2)) for _ in range(10)]
         self.steps = 0
         return self.car
@@ -36,10 +54,10 @@ class CarEnvironment:
         self.steps += 1

         if new_pos in self.obstacles:
-            reward = -5
             done = True
         elif new_pos == self.goal:
-            reward = 10
             done = True
         else:
             reward = -0.1
@@ -49,118 +67,225 @@ class CarEnvironment:
         return new_pos, reward, done

 # ---------- Q-LEARNING ----------
-def q_learning(env, episodes=100):
     q_table = np.zeros((GRID_SIZE, GRID_SIZE, len(ACTIONS)))
-    alpha, gamma, epsilon = 0.5, 0.9, 0.3
-
-    for _ in range(episodes):
         state = env.reset()
         done = False
-        while not done:
-            if random.uniform(0, 1) < epsilon:
                 action = random.choice(ACTIONS)
             else:
                 action = ACTIONS[np.argmax(q_table[state[0], state[1]])]
             next_state, reward, done = env.step(action)
-            old_value = q_table[state[0], state[1], ACTIONS.index(action)]
-            next_max = np.max(q_table[next_state[0], next_state[1]])
-            q_table[state[0], state[1], ACTIONS.index(action)] = old_value + alpha * (reward + gamma * next_max - old_value)
             state = next_state
-    return q_table

-def simulate_path(env, q_table):
     state = env.reset()
     path = [state]
     done = False
-    while not done and len(path) < 100:
         action = ACTIONS[np.argmax(q_table[state[0], state[1]])]
         next_state, _, done = env.step(action)
         path.append(next_state)
         state = next_state
     return path

-# ---------- VISUALIZATION ----------
-def render_scene(path, obstacles, goal, view="3D"):
-    fig = plt.figure(figsize=(6, 6))
-
-    if view == "3D":
-        ax = fig.add_subplot(111, projection="3d")
-        X, Y = np.meshgrid(np.arange(GRID_SIZE), np.arange(GRID_SIZE))
-        Z = np.zeros_like(X)
-        ax.plot_surface(X, Y, Z, color='gray', alpha=0.3)
-
-        # Obstacles
-        for (x, y) in obstacles:
-            ax.bar3d(y, x, 0, 1, 1, 2, color='red', alpha=0.8)
-
-        # Path
-        for (x, y) in path:
-            ax.bar3d(y, x, 0, 1, 1, 0.3, color='dodgerblue', alpha=0.6)
-
-        # Car (last position)
-        car_x, car_y = path[-1]
-        ax.bar3d(car_y, car_x, 0, 1, 1, 1, color='yellow', alpha=1.0)
-
-        # Goal
-        ax.bar3d(goal[1], goal[0], 0, 1, 1, 0.1, color='lime', alpha=1.0)
-
-        ax.set_xlim(0, GRID_SIZE)
-        ax.set_ylim(0, GRID_SIZE)
-        ax.set_zlim(0, 3)
-        ax.set_title("3D Autonomous Car Navigation", fontsize=14, color="white", pad=20)
-        ax.set_facecolor("#1a1a1a")
-
-    else:  # 2D View
-        ax = fig.add_subplot(111)
-        ax.set_xlim(0, GRID_SIZE)
-        ax.set_ylim(0, GRID_SIZE)
-        ax.set_xticks(range(GRID_SIZE))
-        ax.set_yticks(range(GRID_SIZE))
-        ax.grid(True, linestyle='--', alpha=0.4)
-        ax.set_facecolor("#121212")
-
-        for (x, y) in obstacles:
-            ax.add_patch(plt.Rectangle((y, GRID_SIZE-1-x), 1, 1, color="crimson"))
-
-        for (x, y) in path:
-            ax.add_patch(plt.Rectangle((y, GRID_SIZE-1-x), 1, 1, color="dodgerblue", alpha=0.4))
-
-        ax.add_patch(plt.Rectangle((goal[1], GRID_SIZE-1-goal[0]), 1, 1, color="lime"))
-        car_x, car_y = path[-1]
-        ax.add_patch(plt.Rectangle((car_y, GRID_SIZE-1-car_x), 1, 1, color="gold"))
-
-        ax.set_title("2D Autonomous Car Path", color="white", fontsize=14)
-
-    plt.tight_layout()
     buf = BytesIO()
-    plt.savefig(buf, format="png", bbox_inches="tight", facecolor="#1a1a1a")
     plt.close(fig)
     buf.seek(0)
-    img_data = base64.b64encode(buf.read()).decode("utf-8")
-    return f"<img src='data:image/png;base64,{img_data}' style='border-radius:12px; width:100%;'/>"
-
-# ---------- SIMULATION FUNCTION ----------
-def run_simulation(view):
-    env = CarEnvironment()
-    q_table = q_learning(env)
-    path = simulate_path(env, q_table)
-    return render_scene(path, env.obstacles, env.goal, view)
-
-# ---------- GRADIO INTERFACE ----------
-with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="violet")) as demo:
-    gr.Markdown(
-        """
-        <div style='text-align:center; background: linear-gradient(90deg, #0d47a1, #311b92);
-        color:white; padding:20px; border-radius:12px;'>
-        <h1>🚗 AI Car Navigation Simulator</h1>
-        <p>Watch a Reinforcement Learning Agent learn to drive towards its goal while avoiding obstacles — in 2D or 3D!</p>
-        </div>
-        """
-    )
-    view_mode = gr.Radio(["2D", "3D"], value="3D", label="Choose View Mode")
-    run_btn = gr.Button("▶️ Run Simulation", variant="primary")
-    output = gr.HTML(label="Simulation Output")
-
-    run_btn.click(fn=run_simulation, inputs=view_mode, outputs=output)

 demo.launch()
 
+import random
 import numpy as np
+import matplotlib
+matplotlib.use("Agg")  # for headless servers like Hugging Face Spaces
 import matplotlib.pyplot as plt
+from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
 from io import BytesIO
 import base64
+from PIL import Image
+import gradio as gr

+# ---------- ENVIRONMENT ----------
 GRID_SIZE = 8
 ACTIONS = ['up', 'down', 'left', 'right']

 class CarEnvironment:
+    def __init__(self, obstacles=None, seed=None):
+        self.seed = seed
+        self.reset(obstacles)

+    def reset(self, obstacles=None):
+        if self.seed is not None:
+            random.seed(self.seed)
+            np.random.seed(self.seed)
         self.car = (0, 0)
         self.goal = (GRID_SIZE - 1, GRID_SIZE - 1)
+        # deterministic obstacles if provided, else random but reproducible with seed
+        if obstacles:
+            self.obstacles = obstacles
+        else:
+            # ensure obstacles don't overlap start/goal
+            obs = set()
+            while len(obs) < int(GRID_SIZE * 1.25):
+                x = random.randint(1, GRID_SIZE - 2)
+                y = random.randint(1, GRID_SIZE - 2)
+                if (x, y) not in [(0,0), self.goal]:
+                    obs.add((x,y))
+            self.obstacles = list(obs)
         self.steps = 0
         return self.car
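A quick way to exercise the new constructor outside the app (an illustrative sketch, not part of the commit; it only assumes the definitions above):

env = CarEnvironment(seed=42)                  # reproducible obstacle layout
print(env.car, env.goal)                       # (0, 0) (7, 7) for GRID_SIZE = 8
print(len(env.obstacles))                      # int(8 * 1.25) = 10 obstacle cells
replay = CarEnvironment(obstacles=env.obstacles, seed=42)  # reuse the same map later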
 
 
         self.steps += 1

         if new_pos in self.obstacles:
+            reward = -5.0
             done = True
         elif new_pos == self.goal:
+            reward = 10.0
             done = True
         else:
             reward = -0.1
 
         return new_pos, reward, done

 # ---------- Q-LEARNING ----------
+def q_learning(env, episodes=500, alpha=0.7, gamma=0.95, epsilon=0.1):
     q_table = np.zeros((GRID_SIZE, GRID_SIZE, len(ACTIONS)))
+    rewards_per_episode = []
+    for ep in range(episodes):
         state = env.reset()
+        total = 0.0
         done = False
+        steps = 0
+        while not done and steps < 400:
+            if random.random() < epsilon:
                 action = random.choice(ACTIONS)
             else:
                 action = ACTIONS[np.argmax(q_table[state[0], state[1]])]
+
             next_state, reward, done = env.step(action)
+            ai = ACTIONS.index(action)
+            old = q_table[state[0], state[1], ai]
+            # Temporal difference update (Q-learning)
+            q_table[state[0], state[1], ai] = old + alpha * (reward + gamma * np.max(q_table[next_state[0], next_state[1]]) - old)
+
+            total += reward
             state = next_state
+            steps += 1
+        rewards_per_episode.append(total)
+    return q_table, rewards_per_episode
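The temporal-difference line above is the standard tabular Q-learning update, Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a)). A tiny worked example with the new defaults (numbers are made up purely for illustration):

old, reward, next_max = 0.0, -0.1, 2.0   # hypothetical current Q-value, step reward, best next-state value
alpha, gamma = 0.7, 0.95                 # defaults from q_learning above
new_q = old + alpha * (reward + gamma * next_max - old)
print(round(new_q, 2))                   # 0.7 * (-0.1 + 0.95 * 2.0) = 1.26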

+# ---------- SIMULATION / PATH ----------
+def simulate_path(env, q_table, max_steps=200):
     state = env.reset()
     path = [state]
     done = False
+    steps = 0
+    while not done and steps < max_steps:
         action = ACTIONS[np.argmax(q_table[state[0], state[1]])]
         next_state, _, done = env.step(action)
         path.append(next_state)
         state = next_state
+        steps += 1
     return path
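One detail of the greedy rollout above: np.argmax returns the first index on ties, so any state whose Q-values are still all zero maps to 'up'. A one-line check (illustrative only, assumes the imports and ACTIONS defined above):

untrained = np.zeros(len(ACTIONS))       # Q-values of a never-visited state
print(ACTIONS[np.argmax(untrained)])     # 'up' -- ties resolve to the first action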

+# ---------- RENDER HELPERS ----------
+def fig_to_pil(fig, facecolor=None):
     buf = BytesIO()
+    fig.savefig(buf, format="png", bbox_inches='tight', facecolor=facecolor)
     plt.close(fig)
     buf.seek(0)
+    return Image.open(buf).convert("RGBA")
+
+def render_frame_3d(path, obstacles, goal, elev=30, azim=45):
+    fig = plt.figure(figsize=(6,6), facecolor="#111111")
+    ax = fig.add_subplot(111, projection="3d")
+    # floor
+    X, Y = np.meshgrid(np.arange(GRID_SIZE+1), np.arange(GRID_SIZE+1))
+    Z = np.zeros_like(X)
+    ax.plot_surface(X, Y, Z, color='gray', alpha=0.08)
+    # obstacles
+    for (x,y) in obstacles:
+        ax.bar3d(y, x, 0, 1, 1, 1.8, color='red', alpha=0.9)
+    # path bars (lower)
+    for (x,y) in path:
+        ax.bar3d(y, x, 0, 1, 1, 0.25, color='deepskyblue', alpha=0.6)
+    # car (top)
+    car_x, car_y = path[-1]
+    ax.bar3d(car_y, car_x, 0, 1, 1, 0.9, color='gold', alpha=1.0, edgecolor='k')
+    # goal
+    ax.bar3d(goal[1], goal[0], 0, 1, 1, 0.12, color='lime', alpha=1.0)
+    ax.set_xlim(-0.5, GRID_SIZE - 0.5)
+    ax.set_ylim(-0.5, GRID_SIZE - 0.5)
+    ax.set_zlim(0, 3)
+    ax.view_init(elev=elev, azim=azim)
+    ax.set_xticks([])
+    ax.set_yticks([])
+    ax.set_zticks([])
+    ax.set_facecolor("#111111")
+    return fig_to_pil(fig, facecolor="#111111")
+
+def render_frame_2d(path, obstacles, goal):
+    fig = plt.figure(figsize=(6,6), facecolor="#111111")
+    ax = fig.add_subplot(111)
+    ax.set_xlim(0, GRID_SIZE)
+    ax.set_ylim(0, GRID_SIZE)
+    ax.set_xticks(np.arange(0.5, GRID_SIZE, 1))
+    ax.set_yticks(np.arange(0.5, GRID_SIZE, 1))
+    ax.set_xticklabels([])
+    ax.set_yticklabels([])
+    ax.grid(True, color='#2a2a2a', linestyle='--', linewidth=1)
+    ax.set_facecolor("#0f0f0f")
+    # draw obstacles
+    for (x,y) in obstacles:
+        ax.add_patch(plt.Rectangle((y, GRID_SIZE-1-x), 1, 1, color='crimson'))
+    # draw path
+    for (x,y) in path:
+        ax.add_patch(plt.Rectangle((y, GRID_SIZE-1-x), 1, 1, color='deepskyblue', alpha=0.6))
+    # car
+    car_x, car_y = path[-1]
+    ax.add_patch(plt.Rectangle((car_y, GRID_SIZE-1-car_x), 1, 1, color='gold'))
+    # goal
+    ax.add_patch(plt.Rectangle((goal[1], GRID_SIZE-1-goal[0]), 1, 1, color='lime'))
+    ax.set_title("2D View", color="white")
+    return fig_to_pil(fig, facecolor="#111111")
+
+def frames_to_gif(frames, duration_ms=300):
+    # frames: list of PIL.Image
+    # duration_ms per frame
+    buf = BytesIO()
+    # convert to P mode for smaller size & better GIF compatibility
+    frames[0].save(buf, format='GIF', save_all=True, append_images=frames[1:],
+                   duration=duration_ms, loop=0, disposal=2, optimize=True)
+    buf.seek(0)
+    return buf.read()
+
+def img_bytes_to_datauri(img_bytes, mime='image/gif'):
+    return "data:{};base64,{}".format(mime, base64.b64encode(img_bytes).decode('utf-8'))
+
+def plot_reward_curve(rewards):
+    fig = plt.figure(figsize=(6,3), facecolor="#111111")
+    ax = fig.add_subplot(111)
+    ax.plot(rewards, linewidth=1.5)
+    ax.set_xlabel("Episode", color="white")
+    ax.set_ylabel("Total Reward", color="white")
+    ax.set_facecolor("#111111")
+    ax.tick_params(colors="white")
+    fig.tight_layout()
+    return fig_to_pil(fig, facecolor="#111111")
+
+# ---------- GRADIO CALLBACKS & STATE ----------
+def train_agent(episodes, alpha, gamma, epsilon, seed):
+    # create reproducible environment for training
+    env = CarEnvironment(seed=seed)
+    q_table, rewards = q_learning(env, episodes=episodes, alpha=alpha, gamma=gamma, epsilon=epsilon)
+    reward_img = plot_reward_curve(rewards)
+    # store q_table and obstacles/goal for later simulation
+    metadata = {
+        "q_table": q_table,
+        "obstacles": env.obstacles.copy(),
+        "goal": env.goal,
+        "seed": seed
+    }
+    # return metadata as state, and reward image as data URI
+    buf = BytesIO()
+    reward_img.save(buf, format="PNG")
+    buf.seek(0)
+    reward_datauri = "data:image/png;base64," + base64.b64encode(buf.read()).decode("utf-8")
+    return metadata, reward_datauri, f"Trained for {episodes} episodes. Reward (last): {round(rewards[-1], 2)}"
+
+def start_drive(view_mode, speed_ms, rotate_camera, state):
+    # state should contain q_table and map details
+    if not state:
+        return None, "No trained agent found. Please train the agent first."
+    q_table = state["q_table"]
+    obstacles = state["obstacles"]
+    goal = state["goal"]
+    seed = state.get("seed", None)
+    env = CarEnvironment(obstacles=obstacles, seed=seed)
+    path = simulate_path(env, q_table, max_steps=200)
+    # Create frames
+    frames = []
+    # small camera motion parameters
+    base_elev = 30
+    base_azim = 45
+    for i in range(1, len(path)+1):
+        subpath = path[:i]
+        if view_mode == "3D":
+            elev = base_elev + (rotate_camera * (i/len(path)) * 10)
+            azim = base_azim + (rotate_camera * (i/len(path)) * 40)
+            frame = render_frame_3d(subpath, obstacles, goal, elev=elev, azim=azim)
+        else:
+            frame = render_frame_2d(subpath, obstacles, goal)
+        frames.append(frame)
+    # hold on final frame longer
+    if len(frames) >= 1:
+        frames.append(frames[-1])
+    gif_bytes = frames_to_gif(frames, duration_ms=max(50, int(speed_ms)))
+    datauri = img_bytes_to_datauri(gif_bytes, mime='image/gif')
+    info = f"Drive simulated: {len(path)-1} steps. View: {view_mode}. Speed: {speed_ms} ms/frame."
+    return datauri, info
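The old render_scene wrapped its PNG in an <img> tag before handing it to gr.HTML; a similar wrapper for the new GIF data URI would look like the hypothetical helper below (an assumption for illustration, not something this commit defines):

def datauri_to_img_tag(datauri):
    # hypothetical helper: embed the GIF data URI as an image for an HTML output
    return f"<img src='{datauri}' style='width:100%; border-radius:12px;'/>"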
+
+# ---------- GRADIO LAYOUT ----------
+with gr.Blocks(theme=gr.themes.Soft(primary_hue="violet")) as demo:
+    gr.Markdown("""<div style="text-align:center; padding:18px; border-radius:10px;
+                background: linear-gradient(90deg,#0d47a1,#4a148c); color:white">
+                <h2>🚗 AI Car Navigation Lab — Animated 2D / 3D Demo</h2>
+                <p style="margin:0">Train a tabular Q-learning agent, visualize training, then run an animated drive (GIF)</p>
+                </div>""")
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown("### ▶ Training Controls")
+            episodes = gr.Slider(100, 3000, step=100, value=600, label="Training Episodes")
+            alpha = gr.Slider(0.05, 1.0, step=0.05, value=0.7, label="Learning rate α")
+            gamma = gr.Slider(0.1, 0.999, step=0.01, value=0.95, label="Discount factor γ")
+            epsilon = gr.Slider(0.0, 1.0, step=0.05, value=0.15, label="Exploration ε")
+            seed = gr.Number(value=42, label="Random seed (reproducible map)", precision=0)
+            train_btn = gr.Button("🧠 Train Agent", variant="primary")
+            reward_output = gr.Image(label="Reward Curve (training)", interactive=False)
+            train_status = gr.Textbox(label="Training status", interactive=False)
+        with gr.Column(scale=1):
+            gr.Markdown("### ▶ Simulation & Animation")
+            view_mode = gr.Radio(["3D", "2D"], value="3D", label="Visualization Mode")
+            speed_slider = gr.Slider(50, 1000, step=10, value=250, label="Animation Speed (ms per frame)")
+            rotate_cam = gr.Slider(0, 1, step=0.1, value=0.6, label="Subtle camera rotation (3D only)")
+            drive_btn = gr.Button("▶ Start Drive", variant="secondary")
+            gif_output = gr.HTML(label="Animated Drive (GIF)")
+            drive_info = gr.Textbox(label="Simulation info", interactive=False)
+
+    # hidden state to hold the trained model & environment metadata
+    state = gr.State(value=None)
+
+    # wire up callbacks
+    train_btn.click(fn=train_agent, inputs=[episodes, alpha, gamma, epsilon, seed],
+                    outputs=[state, reward_output, train_status])
+
+    drive_btn.click(fn=start_drive, inputs=[view_mode, speed_slider, rotate_cam, state],
+                    outputs=[gif_output, drive_info])
+
+    # helpful footer
+    gr.Markdown("""
+    **Notes:** The agent is tabular Q-learning. Use the sliders to tune hyperparameters.
+    The animation is a GIF generated on-the-fly; download it from the GIF image if you want a clip.
+    """)

 demo.launch()
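For quick testing outside the Gradio UI, the same pieces can be driven from a short script (an illustrative sketch; the __main__ guard and the print are not part of app.py):

if __name__ == "__main__":
    env = CarEnvironment(seed=42)
    q_table, rewards = q_learning(env, episodes=600, alpha=0.7, gamma=0.95, epsilon=0.15)
    path = simulate_path(env, q_table)
    print(f"trained {len(rewards)} episodes; drive took {len(path) - 1} steps; "
          f"reached goal: {path[-1] == env.goal}")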