import gradio as gr
import spaces  # ZeroGPU support: provides the spaces.GPU decorator
from huggingface_hub import login
import os
import re
import json

from json_processor import JsonProcessor
from dag_visualizer import DAGVisualizer

# 1) Read secrets
hf_token = os.getenv("HUGGINGFACE_TOKEN")
if not hf_token:
    raise RuntimeError("❌ HUGGINGFACE_TOKEN not detected. Please check Space Settings → Secrets.")

# 2) Log in so that all subsequent from_pretrained calls have the required permissions
login(hf_token)

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import warnings

warnings.filterwarnings("ignore")
# Model configurations
MODEL_CONFIGS = {
    "1B": {
        "name": "Dart-llm-model-1B",
        "base_model": "meta-llama/Llama-3.2-1B",
        "lora_model": "YongdongWang/llama-3.2-1b-lora-qlora-dart-llm",
    },
    "3B": {
        "name": "Dart-llm-model-3B",
        "base_model": "meta-llama/Llama-3.2-3B",
        "lora_model": "YongdongWang/llama-3.2-3b-lora-qlora-dart-llm",
    },
    "8B": {
        "name": "Dart-llm-model-8B",
        "base_model": "meta-llama/Llama-3.1-8B",
        "lora_model": "YongdongWang/llama-3.1-8b-lora-qlora-dart-llm",
    },
}
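# Note: each entry pairs a base checkpoint with its QLoRA LoRA adapter. The
# meta-llama base checkpoints are gated Hub repositories, so the token passed
# to login() must belong to an account that has been granted access to them.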
DEFAULT_MODEL = "1B"  # Set 1B as the default

# Global state shared between the UI callbacks and the GPU inference path
model = None
tokenizer = None
current_model_config = None  # model size whose tokenizer is currently loaded
gpu_model_config = None      # model size whose weights are currently on the GPU
model_loaded = False
def load_model_and_tokenizer(selected_model=DEFAULT_MODEL):
    """Load the tokenizer - executed on CPU."""
    global tokenizer, model_loaded, current_model_config
    if model_loaded and current_model_config == selected_model:
        return
    print(f"🔄 Loading tokenizer for {MODEL_CONFIGS[selected_model]['name']}...")
    # Load the tokenizer (on CPU)
    base_model = MODEL_CONFIGS[selected_model]["base_model"]
    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        use_fast=False,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    current_model_config = selected_model
    model_loaded = True
    print("✅ Tokenizer loaded successfully!")
# Load the model on the GPU (called lazily from the GPU inference path, not at startup)
def load_model_on_gpu(selected_model=DEFAULT_MODEL):
    """Load the 4-bit quantized base model plus its LoRA adapter on the GPU."""
    global model, gpu_model_config
    # If the requested model is already loaded on the GPU, reuse it
    if model is not None and gpu_model_config == selected_model:
        return model
    # Clear the existing model when switching configurations
    if model is not None:
        print("🗑️ Clearing existing model from GPU...")
        del model
        torch.cuda.empty_cache()
        model = None
        gpu_model_config = None
    model_config = MODEL_CONFIGS[selected_model]
    print(f"🔄 Loading {model_config['name']} on GPU...")
    try:
        # 4-bit quantization configuration: NF4 weights, double quantization, fp16 compute
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
        )
        # Load the base model
        base_model = AutoModelForCausalLM.from_pretrained(
            model_config["base_model"],
            quantization_config=bnb_config,
            device_map="auto",
            torch_dtype=torch.float16,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
        # Attach the LoRA adapter
        model = PeftModel.from_pretrained(
            base_model,
            model_config["lora_model"],
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        model.eval()
        gpu_model_config = selected_model
        print(f"✅ {model_config['name']} loaded on GPU successfully!")
        return model
    except Exception as load_error:
        print(f"❌ Model loading failed: {load_error}")
        raise
def process_json_in_response(response):
    """Process and format any JSON content in the response."""
    try:
        # Check whether the response contains JSON-like content
        if '{' in response and '}' in response:
            processor = JsonProcessor()
            # Try to extract structured task data from the response
            processed_json = processor.process_response(response)
            if processed_json:
                # Format the JSON nicely
                formatted_json = json.dumps(processed_json, indent=2, ensure_ascii=False)
                # Replace the raw JSON span in the response with the formatted version
                match = re.search(r'\{.*\}', response, re.DOTALL)
                if match:
                    response = response.replace(match.group(), formatted_json)
        return response
    except Exception:
        # If processing fails, return the original response unchanged
        return response
# GPU inference (the spaces.GPU decorator asks ZeroGPU to attach a GPU for the call)
@spaces.GPU
def generate_response_gpu(prompt, max_tokens=512, selected_model=DEFAULT_MODEL):
    """Generate a response - executed on the GPU."""
    global model
    # Ensure the tokenizer is loaded
    if tokenizer is None or current_model_config != selected_model:
        load_model_and_tokenizer(selected_model)
    # Ensure the requested model is resident on the GPU
    if model is None or gpu_model_config != selected_model:
        model = load_model_on_gpu(selected_model)
    if model is None:
        return "❌ Model failed to load. Please check the Space logs."
    try:
        # Alpaca-style instruction/response prompt template
        formatted_prompt = (
            "### Instruction:\n"
            f"{prompt.strip()}\n\n"
            "### Response:\n"
        )
        # Encode the input
        inputs = tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
        ).to(model.device)
        # Generate the response (greedy decoding with a light repetition penalty)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=False,
                temperature=None,
                top_p=None,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
                no_repeat_ngram_size=3,
            )
        # Decode the output
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Keep only the generated part
        if "### Response:" in response:
            response = response.split("### Response:")[-1].strip()
        elif len(response) > len(formatted_prompt):
            response = response[len(formatted_prompt):].strip()
        # Pretty-print any JSON contained in the response
        response = process_json_in_response(response)
        return response if response else "⚠️ No response generated. Please try again with a different prompt."
    except Exception as generation_error:
        return f"❌ Generation Error: {generation_error}"
def create_dag_visualization(task_json_str):
    """Create a DAG visualization from task JSON."""
    try:
        if not task_json_str.strip():
            return None, "Please provide task JSON data"
        # Parse the JSON
        task_data = json.loads(task_json_str)
        # Create the DAG visualizer
        dag_visualizer = DAGVisualizer()
        # Generate the visualization
        image_path = dag_visualizer.create_dag_visualization(task_data)
        if image_path:
            return image_path, "✅ DAG visualization created successfully!"
        else:
            return None, "❌ Failed to create DAG visualization"
    except json.JSONDecodeError as e:
        return None, f"❌ JSON Parse Error: {e}"
    except Exception as e:
        return None, f"❌ DAG Creation Error: {e}"
def chat_interface(message, history, max_tokens, selected_model):
    """Chat interface - runs on CPU and calls the GPU inference function."""
    if not message.strip():
        return history, ""
    # Initialize the tokenizer if needed
    if tokenizer is None or current_model_config != selected_model:
        load_model_and_tokenizer(selected_model)
    try:
        # Call the GPU function to generate a response
        response = generate_response_gpu(message, max_tokens, selected_model)
        history.append((message, response))
        return history, ""
    except Exception as chat_error:
        error_msg = f"❌ Chat Error: {chat_error}"
        history.append((message, error_msg))
        return history, ""
# Load the tokenizer at startup with the default model
load_model_and_tokenizer(DEFAULT_MODEL)
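# Note: on ZeroGPU hardware a GPU is only attached while a spaces.GPU-decorated
# function is running, so the quantized model itself is loaded lazily inside
# generate_response_gpu() rather than here at startup.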
# Create the Gradio application
with gr.Blocks(
    title="Robot Task Planning - DART-LLM Multi-Model",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {
        max-width: 1200px;
        margin: auto;
    }
    """
) as app:
    gr.Markdown("""
    # 🤖 DART-LLM Multi-Model - Robot Task Planning

    Choose from **three models** fine-tuned for **robot task planning** with QLoRA:
    - **Dart-llm-model-1B**: Ready for Jetson Nano deployment (870MB GGUF)
    - **Dart-llm-model-3B**: Ready for Jetson Xavier NX deployment (1.9GB GGUF)
    - **Dart-llm-model-8B**: Ready for Jetson AGX Xavier/Orin deployment (4.6GB GGUF)

    **Capabilities**: Convert natural-language robot commands into structured task sequences for excavators, dump trucks, and other construction robots. **Edge-ready for Jetson devices, with DAG visualization!**

    ## 🔧 Recommended for Jetson Deployment (GGUF Models)

    For the best edge-deployment performance, use these GGUF quantized models:
    - **[YongdongWang/llama-3.2-1b-lora-qlora-dart-llm-gguf](https://huggingface.co/YongdongWang/llama-3.2-1b-lora-qlora-dart-llm-gguf)** (870MB) - Jetson Nano / Orin Nano
    - **[YongdongWang/llama-3.2-3b-lora-qlora-dart-llm-gguf](https://huggingface.co/YongdongWang/llama-3.2-3b-lora-qlora-dart-llm-gguf)** (1.9GB) - Jetson Orin NX / AGX Orin
    - **[YongdongWang/llama-3.1-8b-lora-qlora-dart-llm-gguf](https://huggingface.co/YongdongWang/llama-3.1-8b-lora-qlora-dart-llm-gguf)** (4.6GB) - High-end Jetson AGX Orin

    💡 **Deploy with**: Ollama, llama.cpp, or llama-cpp-python for efficient edge inference
    """)
    with gr.Tabs():
        with gr.Tab("💬 Task Planning"):
            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        label="Task Planning Results",
                        height=500,
                        show_label=True,
                        container=True,
                        bubble_full_width=False,
                        show_copy_button=True
                    )
                    msg = gr.Textbox(
                        label="Robot Command",
                        placeholder="Enter a robot task command (e.g. 'Deploy Excavator 1 to Soil Area 1')...",
                        lines=2,
                        max_lines=5,
                        show_label=True,
                        container=True
                    )
                    with gr.Row():
                        send_btn = gr.Button("🚀 Generate Tasks", variant="primary", size="sm")
                        clear_btn = gr.Button("🗑️ Clear", variant="secondary", size="sm")
                with gr.Column(scale=1):
                    gr.Markdown("### ⚙️ Generation Settings")
                    model_selector = gr.Dropdown(
                        choices=[(config["name"], key) for key, config in MODEL_CONFIGS.items()],
                        value=DEFAULT_MODEL,
                        label="Model Size",
                        info="Select the model for your Jetson device (1B = Nano, 3B = Xavier NX, 8B = AGX)",
                        interactive=True
                    )
                    max_tokens = gr.Slider(
                        minimum=50,
                        maximum=5000,
                        value=512,
                        step=10,
                        label="Max Tokens",
                        info="Maximum number of tokens to generate"
                    )
                    gr.Markdown("""
                    ### 🔧 GGUF Models for Jetson Deployment
                    **Recommended for edge deployment:**
                    - **1B (870MB)**: Jetson Nano / Orin Nano (2GB RAM)
                    - **3B (1.9GB)**: Jetson Orin NX / AGX Orin (4GB RAM)
                    - **8B (4.6GB)**: High-end Jetson AGX Orin (8GB RAM)

                    💡 Use **Ollama** or **llama.cpp** for efficient inference
                    """)
| with gr.Tab("π DAG Visualization"): | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| json_input = gr.Textbox( | |
| label="Task JSON Data", | |
| placeholder="Paste the generated task JSON here to create a DAG visualization...", | |
| lines=15, | |
| max_lines=25, | |
| show_label=True, | |
| container=True | |
| ) | |
| with gr.Row(): | |
| dag_btn = gr.Button("π¨ Generate DAG", variant="primary", size="sm") | |
| dag_clear_btn = gr.Button("ποΈ Clear", variant="secondary", size="sm") | |
| dag_status = gr.Textbox( | |
| label="Status", | |
| value="Ready to generate DAG visualization", | |
| interactive=False, | |
| show_label=True | |
| ) | |
| with gr.Column(scale=3): | |
| dag_output = gr.Image( | |
| label="Task Dependency Graph", | |
| show_label=True, | |
| container=True, | |
| height=600 | |
| ) | |
| gr.Markdown(""" | |
| ### π DAG Features | |
| - **Node Colors**: Red (Start), Orange (Intermediate), Purple (End) | |
| - **Arrows**: Show task dependencies | |
| - **Layout**: Hierarchical based on dependencies | |
| - **Details**: Task info boxes with robots and objects | |
| """) | |
    # Example commands
    gr.Examples(
        examples=[
            "Dump Truck 1 goes to the puddle for inspection, after which all robots avoid the puddle.",
            "Drive Excavator 1 to the obstacle, and perform excavation to clear the obstacle.",
            "Send Excavator 1 and Dump Truck 1 to the soil area; Excavator 1 will excavate and unload, followed by Dump Truck 1 proceeding to the puddle for unloading.",
            "Move Excavator 1 and Dump Truck 1 to soil area 2; Excavator 1 will excavate and unload, then Dump Truck 1 returns to the starting position to unload.",
            "Excavator 1 is guided to the obstacle to excavate and unload to clear the obstacle, then Excavator 1 and Dump Truck 1 are moved to the soil area, and the excavator excavates and unloads. Finally, Dump Truck 1 unloads the soil into the puddle.",
            "Excavator 1 goes to the obstacle to excavate and unload to clear the obstacle. Once the obstacle is cleared, mobilize all available robots to proceed to the puddle area for inspection.",
        ],
        inputs=msg,
        label="💡 Example Operator Commands"
    )

    # Event handling
    msg.submit(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, model_selector],
        outputs=[chatbot, msg]
    )
    send_btn.click(
        chat_interface,
        inputs=[msg, chatbot, max_tokens, model_selector],
        outputs=[chatbot, msg]
    )
    clear_btn.click(
        lambda: ([], ""),
        outputs=[chatbot, msg]
    )

    # DAG visualization event handlers
    dag_btn.click(
        create_dag_visualization,
        inputs=[json_input],
        outputs=[dag_output, dag_status]
    )
    dag_clear_btn.click(
        lambda: ("", None, "Ready to generate DAG visualization"),
        outputs=[json_input, dag_output, dag_status]
    )

if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )