import gradio as gr
import os
import threading
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


class SimpleTokenizer:
    """A simple word-level tokenizer for faster startup"""
    def __init__(self):
        self.vocab = {}
        self.inverse_vocab = {}
        self.vocab_size = 0
        self.pad_token = "<pad>"
        self.pad_token_id = 0
        self.eos_token = "<eos>"
        self.eos_token_id = 1
        self.unk_token = "<unk>"
        self.unk_token_id = 2
        # Start with the special tokens
        self.add_token(self.pad_token)  # ID 0
        self.add_token(self.eos_token)  # ID 1
        self.add_token(self.unk_token)  # ID 2

    def add_token(self, token):
        if token not in self.vocab:
            self.vocab[token] = self.vocab_size
            self.inverse_vocab[self.vocab_size] = token
            self.vocab_size += 1
            return True
        return False

    def build_vocab_from_texts(self, texts, max_vocab_size=10000):
        """Build the vocabulary from all training texts"""
        print("Building vocabulary from training data...")
        # Count all tokens
        token_counter = Counter()
        for text in texts:
            token_counter.update(text.split())
        # Add the most frequent tokens to the vocabulary
        for token, _ in token_counter.most_common(max_vocab_size - self.vocab_size):
            self.add_token(token)
        print(f"Vocabulary built with {self.vocab_size} tokens")

    def tokenize(self, text):
        # Simple word-level tokenization; out-of-vocabulary words map to <unk>
        return [self.vocab.get(token, self.unk_token_id) for token in text.split()]

    def encode(self, text, max_length=None, padding=False, truncation=False):
        token_ids = self.tokenize(text)
        if truncation and max_length and len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
        if padding and max_length and len(token_ids) < max_length:
            token_ids = token_ids + [self.pad_token_id] * (max_length - len(token_ids))
        return token_ids

    def decode(self, token_ids):
        # Drop padding tokens for cleaner output
        filtered_ids = [token_id for token_id in token_ids if token_id != self.pad_token_id]
        return " ".join(self.inverse_vocab.get(token_id, self.unk_token) for token_id in filtered_ids)


class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Keep only non-empty texts
        self.texts = [text for text in texts if text.strip()]

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        # Guard against empty text
        if not text.strip():
            text = " "
        token_ids = self.tokenizer.encode(
            text,
            max_length=self.max_length,
            padding=True,
            truncation=True
        )
        # Clamp IDs into the valid vocabulary range before building tensors
        token_ids = [min(token_id, self.tokenizer.vocab_size - 1) for token_id in token_ids]
        return {
            'input_ids': torch.tensor(token_ids, dtype=torch.long),
            'labels': torch.tensor(token_ids, dtype=torch.long)
        }
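
# Batch sketch (illustrative): a DataLoader over this dataset with batch_size=4
# yields dicts of LongTensors:
#   batch['input_ids'].shape  # torch.Size([4, 512])
#   batch['labels'].shape     # torch.Size([4, 512])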


class SimpleGPT(nn.Module):
    """A simplified GPT-like model for faster training"""
    def __init__(self, vocab_size, d_model=512, n_layers=6, n_heads=8, max_seq_len=512):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        # Token and position embeddings (padding_idx=0 matches the pad token)
        self.token_embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.position_embedding = nn.Embedding(max_seq_len, d_model)
        # Transformer layers
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_model * 4,
            batch_first=True,
            dropout=0.1
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        # Output layer with dropout for regularization
        self.dropout = nn.Dropout(0.1)
        self.output_layer = nn.Linear(d_model, vocab_size)
        # Initialize weights properly
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
    def forward(self, input_ids, labels=None):
        batch_size, seq_len = input_ids.shape
        # Ensure all token IDs are within the valid range
        input_ids = torch.clamp(input_ids, 0, self.vocab_size - 1)
        # Token embeddings
        token_embeds = self.token_embedding(input_ids)
        # Position embeddings
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0).expand(batch_size, seq_len)
        position_embeds = self.position_embedding(positions)
        # Combine embeddings
        x = token_embeds + position_embeds
        # Causal mask so each position attends only to earlier positions
        # (without it the encoder is bidirectional, not GPT-like)
        causal_mask = torch.triu(
            torch.full((seq_len, seq_len), float('-inf'), device=input_ids.device),
            diagonal=1
        )
        # Padding mask so attention ignores pad tokens (ID 0)
        padding_mask = input_ids == 0
        x = self.transformer(x, mask=causal_mask, src_key_padding_mask=padding_mask)
        # Apply dropout
        x = self.dropout(x)
        # Output
        logits = self.output_layer(x)
        # Calculate loss if labels are provided
        loss = None
        if labels is not None:
            # Ensure labels are within the valid range
            labels = torch.clamp(labels, 0, self.vocab_size - 1)
            # Shift so position t is trained to predict token t+1;
            # otherwise the model just learns to copy its input
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            # ignore_index=0 skips padding positions in the loss
            loss_fn = nn.CrossEntropyLoss(ignore_index=0)
            loss = loss_fn(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))
        return {'logits': logits, 'loss': loss}
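
# Shape sketch (illustrative): with batch_size=4 and max_seq_len=512,
#   out = model(input_ids)  # input_ids: LongTensor of shape (4, 512)
#   out['logits'].shape     # torch.Size([4, 512, vocab_size])
# A minimal greedy decoder (not part of this app's UI) would repeatedly take
# out['logits'][:, -1, :].argmax(-1) and append it to the input sequence.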


class AITrainerApp:
    def __init__(self):
        # Use the simple tokenizer for faster startup
        self.tokenizer = SimpleTokenizer()
        self.model = None
        self.training_data = []
        # Default model configuration
        self.model_config = {
            "d_model": 512,
            "n_layers": 6,
            "n_heads": 8,
            "max_seq_len": 512
        }
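        # Rough scale with these defaults and a ~10k vocabulary (estimate):
        # each encoder layer is ~3.2M parameters (attention + 4x FFN), so
        # 6 layers is roughly 19M, plus ~10M of token/position/output
        # embeddings, i.e. on the order of 30M parameters total.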
        # Training control
        self.training_thread = None
        self.stop_training_flag = False
        self.training_status = "Ready - Load training data to begin"
        self.output_log = "Training output will appear here...\n"

    def get_device(self, device_type="auto"):
        """Get the selected device based on the user's choice"""
        if device_type == "auto":
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        elif device_type == "cuda":
            # Fall back to CPU if CUDA was requested but is unavailable
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            return torch.device('cpu')

    def log_output(self, message):
        """Append a message to the output log"""
        self.output_log += message + "\n"
        return self.output_log

    def verify_model_file(self, file_path):
        """Verify that a model file looks valid before loading
        (helper; not currently wired into the UI)"""
        try:
            # Simple file checks
            if not os.path.exists(file_path):
                return False, "File does not exist"
            if os.path.getsize(file_path) < 1024:  # less than 1 KB
                return False, "File is too small to be a valid model"
            return True, "File appears valid"
        except Exception as e:
            return False, f"Error verifying file: {str(e)}"

    def load_training_files(self, files):
        """Load training text from files uploaded via the Gradio File component"""
        if not files:
            return "No files selected", self.output_log
        total_texts = []
        for file_info in files:
            # Gradio delivers uploads as temp-file paths (a plain string or an
            # object with a .name attribute, depending on the Gradio version),
            # so resolve the path first and read from disk
            file_path = file_info if isinstance(file_info, str) else file_info.name
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
                # Split into smaller chunks if needed
                chunks = self.split_into_chunks(content, 1000)
                total_texts.extend(chunks)
                self.output_log = self.log_output(f"Loaded {len(chunks)} chunks from {os.path.basename(file_path)}")
            except Exception as e:
                error_msg = f"Error reading {file_path}: {str(e)}"
                self.output_log = self.log_output(error_msg)
                return error_msg, self.output_log
        self.training_data.extend(total_texts)
        # Build the vocabulary from all training texts
        self.tokenizer.build_vocab_from_texts(self.training_data, max_vocab_size=10000)
        status_msg = f"Loaded {len(total_texts)} text chunks from {len(files)} files"
        self.output_log = self.log_output(status_msg)
        self.output_log = self.log_output(f"Vocabulary size: {self.tokenizer.vocab_size}")
        return status_msg, self.output_log

    def split_into_chunks(self, text, chunk_size):
        words = text.split()
        chunks = []
        for i in range(0, len(words), chunk_size):
            chunks.append(' '.join(words[i:i + chunk_size]))
        return chunks
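    # Example (illustrative): a 2,500-word text with chunk_size=1000 yields
    # three chunks of 1000, 1000, and 500 words.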

    def view_training_data(self):
        if not self.training_data:
            return "No training data loaded"
        preview = ""
        for i, text in enumerate(self.training_data[:50]):  # show the first 50 chunks
            preview += f"Chunk {i+1}:\n{text}\n\n{'='*50}\n\n"
        return preview

    def start_training(self, d_model, n_layers, n_heads, batch_size, learning_rate, epochs, device_type):
        if not self.training_data:
            error_msg = "Error: No training data loaded!"
            self.output_log = self.log_output(error_msg)
            return error_msg, self.output_log, gr.update(interactive=True)
        self.stop_training_flag = False
        self.training_status = "Training started..."
        self.output_log = self.log_output("Starting training...")
        # Update the model config from the UI
        self.model_config.update({
            "d_model": int(d_model),
            "n_layers": int(n_layers),
            "n_heads": int(n_heads)
        })
        # Run training in a separate thread so the UI stays responsive
        self.training_thread = threading.Thread(
            target=self.train_model,
            args=(int(batch_size), float(learning_rate), int(epochs), device_type)
        )
        self.training_thread.daemon = True
        self.training_thread.start()
        return "Training started...", self.output_log, gr.update(interactive=False)

    def stop_training(self):
        self.stop_training_flag = True
        self.training_status = "Stopping training..."
        self.output_log = self.log_output("Stopping training...")
        return "Stopping training...", self.output_log, gr.update(interactive=True)

    def train_model(self, batch_size, learning_rate, epochs, device_type):
        try:
            # Create the dataset and dataloader
            dataset = TextDataset(self.training_data, self.tokenizer)
            dataloader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True
            )
            # Initialize the model
            self.model = SimpleGPT(
                vocab_size=self.tokenizer.vocab_size,
                d_model=self.model_config["d_model"],
                n_layers=self.model_config["n_layers"],
                n_heads=self.model_config["n_heads"],
                max_seq_len=self.model_config["max_seq_len"]
            )
            # Set up the optimizer
            optimizer = optim.AdamW(
                self.model.parameters(),
                lr=learning_rate
            )
            # Training loop
            device = self.get_device(device_type)
            self.model.to(device)
            self.output_log = self.log_output(f"Using device: {device}")
            for epoch in range(epochs):
                if self.stop_training_flag:
                    break
                self.model.train()
                total_loss = 0
                total_batches = 0
                for batch_idx, batch in enumerate(dataloader):
                    if self.stop_training_flag:
                        break
                    optimizer.zero_grad()
                    input_ids = batch['input_ids'].to(device)
                    labels = batch['labels'].to(device)
                    # Debug: check for invalid token IDs
                    max_id = input_ids.max().item()
                    if max_id >= self.tokenizer.vocab_size:
                        self.output_log = self.log_output(
                            f"Warning: Found token ID {max_id} but vocab size is {self.tokenizer.vocab_size}"
                        )
                    # Clamp values into the valid range
                    input_ids = torch.clamp(input_ids, 0, self.tokenizer.vocab_size - 1)
                    labels = torch.clamp(labels, 0, self.tokenizer.vocab_size - 1)
                    outputs = self.model(input_ids=input_ids, labels=labels)
                    loss = outputs['loss']
                    if torch.isnan(loss) or torch.isinf(loss):
                        self.output_log = self.log_output("Warning: NaN or Inf loss detected, skipping batch")
                        continue
                    loss.backward()
                    # Gradient clipping to prevent gradient explosions
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    optimizer.step()
                    total_loss += loss.item()
                    total_batches += 1
                    if batch_idx % 10 == 0:
                        status_msg = f"Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}"
                        self.training_status = status_msg
                        if batch_idx % 50 == 0:  # log less frequently to avoid slowing the UI
                            self.output_log = self.log_output(status_msg)
                if total_batches > 0:
                    avg_loss = total_loss / total_batches
                    epoch_msg = f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}"
                    self.training_status = epoch_msg
                    self.output_log = self.log_output(epoch_msg)
            if not self.stop_training_flag:
                completion_msg = "Training completed successfully!"
                self.training_status = completion_msg
                self.output_log = self.log_output(completion_msg)
        except Exception as e:
            error_msg = f"Training error: {str(e)}"
            self.training_status = error_msg
            self.output_log = self.log_output(error_msg)
            import traceback
            self.output_log = self.log_output(traceback.format_exc())
        finally:
            # This runs on a background thread, so a return value would never
            # reach Gradio; just reset the stop flag here.
            self.stop_training_flag = False

    def save_model(self, file_path):
        if self.model is None:
            self.output_log = self.log_output("Error: No model to save!")
            return "Error: No model to save!", self.output_log
        try:
            torch.save({
                'model_state_dict': self.model.state_dict(),
                'tokenizer': self.tokenizer,
                'config': self.model_config,
                'training_data_info': {
                    'num_chunks': len(self.training_data),
                    'vocab_size': self.tokenizer.vocab_size
                }
            }, file_path)
            success_msg = f"Model saved to {file_path}"
            self.training_status = success_msg
            self.output_log = self.log_output(success_msg)
            return success_msg, self.output_log
        except Exception as e:
            error_msg = f"Error saving model: {str(e)}"
            self.output_log = self.log_output(error_msg)
            return error_msg, self.output_log

    def load_model(self, file_path):
        if not file_path:
            # Return one value per output component wired to this handler
            return "No file selected", self.output_log, gr.update(), gr.update(), gr.update()
        try:
            # weights_only=False is needed to unpickle the saved tokenizer
            # object (newer PyTorch versions default to weights_only=True)
            checkpoint = torch.load(file_path, map_location='cpu', weights_only=False)
            # Recreate the model architecture
            self.model_config = checkpoint['config']
            self.model = SimpleGPT(
                vocab_size=checkpoint['tokenizer'].vocab_size,
                d_model=self.model_config["d_model"],
                n_layers=self.model_config["n_layers"],
                n_heads=self.model_config["n_heads"],
                max_seq_len=self.model_config["max_seq_len"]
            )
            # Load the weights
            self.model.load_state_dict(checkpoint['model_state_dict'])
            # Load the tokenizer
            self.tokenizer = checkpoint['tokenizer']
            success_msg = f"Model loaded from {file_path}"
            self.training_status = success_msg
            self.output_log = self.log_output(success_msg)
            return success_msg, self.output_log, self.model_config['d_model'], self.model_config['n_layers'], self.model_config['n_heads']
        except Exception as e:
            error_msg = f"Error loading model: {str(e)}"
            self.output_log = self.log_output(error_msg)
            return error_msg, self.output_log, gr.update(), gr.update(), gr.update()
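
# Save/load round trip (illustrative):
#   app.save_model("model.pth")  # writes weights, tokenizer, and config
#   app.load_model("model.pth")  # rebuilds the model from that checkpoint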


# Create the app instance
app = AITrainerApp()

# Create the Gradio interface
with gr.Blocks(title="AI Text Generation Trainer") as demo:
    gr.Markdown("# AI Text Generation Trainer")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Controls")
            # Data management
            gr.Markdown("### Data Management")
            file_input = gr.File(file_count="multiple", label="Training Files")
            load_btn = gr.Button("Load Text Files")
            view_data_btn = gr.Button("View Training Data")
            data_preview = gr.Textbox(label="Training Data Preview", lines=10, interactive=False)
            # Device selection
            gr.Markdown("### Device Selection")
            device_type = gr.Radio(
                choices=["auto", "cpu", "cuda"],
                value="auto",
                label="Processing Device"
            )
            device_info = gr.Textbox(
                label="Device Info",
                value=f"GPU available: {'Yes' if torch.cuda.is_available() else 'No'}",
                interactive=False
            )
            # Model configuration
            gr.Markdown("### Model Configuration")
            d_model = gr.Number(value=512, label="Embedding Size")
            n_layers = gr.Number(value=6, label="Number of Layers")
            n_heads = gr.Number(value=8, label="Number of Heads")
            # Training parameters
            gr.Markdown("### Training Parameters")
            batch_size = gr.Number(value=4, label="Batch Size")
            learning_rate = gr.Number(value=0.001, label="Learning Rate")
            epochs = gr.Number(value=3, label="Epochs")
            # Training controls
            gr.Markdown("### Training Control")
            start_btn = gr.Button("Start Training", variant="primary")
            stop_btn = gr.Button("Stop Training")
            # Save / load buttons
            gr.Markdown("### Save / Load Model")
            save_path = gr.Textbox(label="Save Path", value="model.pth")
            save_btn = gr.Button("Save Model")
            load_path = gr.Textbox(label="Load Path", value="model.pth")
| load_btn = gr.Button("Load Model") | |
        with gr.Column(scale=2):
            gr.Markdown("## Status & Output")
            status = gr.Textbox(label="Status", value=app.training_status, interactive=False)
            output = gr.Textbox(label="Output Log", value=app.output_log, lines=20, interactive=False)

    # Define event handlers (inside the Blocks context)
    load_btn.click(
        app.load_training_files,
        inputs=[file_input],
        outputs=[status, output]
    )
    view_data_btn.click(
        app.view_training_data,
        inputs=[],
        outputs=[data_preview]
    )
    start_btn.click(
        app.start_training,
        inputs=[d_model, n_layers, n_heads, batch_size, learning_rate, epochs, device_type],
        outputs=[status, output, start_btn]
    )
    stop_btn.click(
        app.stop_training,
        inputs=[],
        outputs=[status, output, start_btn]
    )
    save_btn.click(
        app.save_model,
        inputs=[save_path],
        outputs=[status, output]
    )
    load_model_btn.click(
        app.load_model,
        inputs=[load_path],
        outputs=[status, output, d_model, n_layers, n_heads]
    )

if __name__ == "__main__":
    demo.launch()