---
license: apache-2.0
language:
- en
datasets:
- PleIAs/SYNTH
library_name: transformers
tags:
- pytorch
- causal-lm
- neuroblast
- experimental
- tpu
---

# NeuroBLAST-V3-SYNTH-EC-150000-JAX
**⚠️ EXPERIMENTAL EARLY JAX CHECKPOINT ⚠️**
This is an **Early JAX Checkpoint (EC)** of the **NeuroBLAST V3** architecture, a novel hybrid model designed with a biologically inspired "cortical" structure. This specific checkpoint (`150k` steps) represents the "pre-decay" phase of training: it has been trained on short contexts with a high learning rate and is intended for architectural evaluation and research purposes.

## Model Details

* **Architecture:** NeuroBLAST V3 (Custom Hybrid Architecture)
* **Checkpoint Step:** 150,000
* **Parameters:** 596,728,320
* **Num layers:** 72
* **Sensory layers:** 24
* **Associative layers:** 32
* **Motor layers:** 16
* **Hidden size:** 512
* **Vocab size:** 65538
* **Intermediate size:** 3072
* **Num attention heads:** 16
* **Num KV heads:** 8
* **Head dim:** 128
* **Tie word embeddings:** False

### Architecture Highlights

NeuroBLAST differs from standard Transformers by utilizing a three-stage cortical design:

1. **Sensory Cortex:** Hybrid layers alternating between Attention and Dilated Causal 2D Convolutions.
2. **Associative Cortex:** Hybrid layers with alternating RoPE usage.
3. **Motor Cortex:** Pure Attention layers.
4. **Deep Residual Bridges:** Long-range residual connections injecting the original embeddings (and their negations) between cortical stages to improve signal propagation (see the illustrative sketch in the Usage section below).

![architecture](arch.jpeg)

## Training Details

This model is currently being trained using the Google TPU Research Cloud (TRC).

* **Dataset:** [PleIAs/SYNTH](https://huggingface.co/datasets/PleIAs/SYNTH)
* **Tokens Processed:** ~118 Billion
* **Hardware:** TPUv4-16
* **Training Time:** ~8 Days
* **Effective Batch Size:** 1024
* **Context Length:** 768 tokens (current phase)
* **Learning rate:** 4e-3
* **Weight decay:** 0.0
* **Optimizer:** AdamW
* **Precision:** BFloat16
* **Current State:** Pre-decay phase (no weight decay applied yet)

![eval_loss](eval_loss.png)

### Roadmap

This checkpoint marks the end of the initial warmup/learning phase. The next steps in training are:

1. Significantly extending the context length.
2. Lowering the learning rate.
3. Introducing weight decay for convergence.

## Usage

**Note:** You must use `trust_remote_code=True`, as this model utilizes custom modeling code (`modeling_neuroblast.py`).
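For orientation, the sketch below shows how the three cortical stages and the deep residual bridges fit together. It is an illustration only, not the actual modeling code: block internals are replaced with plain MLPs, all names and toy depths are hypothetical, and the exact bridging scheme (adding the embeddings after the sensory stage and their negation after the associative stage) is an assumption. The authoritative implementation is the custom code shipped with the checkpoints (`modeling_neuroblast.py` / `modeling_neuroblast_jax.py`).

```python
# Illustrative sketch only; NOT the actual NeuroBLAST implementation.
# Cortical blocks are replaced by plain MLPs, and the exact bridge scheme
# (+embeddings after the sensory stage, -embeddings after the associative
# stage) is an assumption made for illustration.
import jax
import jax.numpy as jnp


def mlp_block(params, x):
    # Stand-in for a full cortical block (attention / dilated conv / MLP).
    h = jax.nn.gelu(x @ params["w1"] + params["b1"])
    return x + h @ params["w2"] + params["b2"]  # local residual


def cortical_stage(blocks, x):
    # One "cortex" (sensory, associative, or motor) is a stack of blocks.
    for p in blocks:
        x = mlp_block(p, x)
    return x


def forward(params, token_ids):
    emb = params["embedding"][token_ids]           # (batch, seq, hidden)
    x = cortical_stage(params["sensory"], emb)     # sensory cortex
    x = x + emb                                    # bridge: re-inject the original embeddings
    x = cortical_stage(params["associative"], x)   # associative cortex
    x = x - emb                                    # bridge: inject the negated embeddings
    x = cortical_stage(params["motor"], x)         # motor cortex
    return x @ params["embedding"].T               # tied head for brevity (the real model unties it)


def init_params(rng, vocab=65538, hidden=512, ffn=3072, depths=(2, 2, 2)):
    # Toy depths; the real model uses 24 / 32 / 16 layers per stage.
    def block(r):
        k1, k2 = jax.random.split(r)
        return {
            "w1": 0.02 * jax.random.normal(k1, (hidden, ffn)),
            "b1": jnp.zeros(ffn),
            "w2": 0.02 * jax.random.normal(k2, (ffn, hidden)),
            "b2": jnp.zeros(hidden),
        }

    keys = jax.random.split(rng, sum(depths) + 1)
    params = {"embedding": 0.02 * jax.random.normal(keys[0], (vocab, hidden))}
    offset = 1
    for name, depth in zip(("sensory", "associative", "motor"), depths):
        params[name] = [block(keys[offset + j]) for j in range(depth)]
        offset += depth
    return params


logits = forward(init_params(jax.random.PRNGKey(0)), jnp.zeros((1, 8), dtype=jnp.int32))
print(logits.shape)  # (1, 8, 65538)
```

Running this as plain Python prints the output shape; the point is only to show how the embeddings re-enter the residual stream between stages.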
```python
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer

from neuroblast3_jax.modeling_neuroblast_jax import NeuroBLASTForCausalLM as NeuroBLASTForCausalLMJax


def generate_text(model, tokenizer, text, max_new_tokens=50, temperature=0.7, top_k=50):
    # Note: top_k is kept for API symmetry but is not applied by this simple sampler.
    inputs = tokenizer(f"user\n{text}<|im_end|><|im_start|>assistant\n", return_tensors="np")
    original_input_ids = inputs["input_ids"]

    batch_size, prompt_len = original_input_ids.shape
    total_len = prompt_len + max_new_tokens

    # Pad input_ids to total_len so the jitted forward pass sees a fixed shape
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    input_ids = jnp.full((batch_size, total_len), pad_id, dtype=jnp.int32)
    input_ids = input_ids.at[:, :prompt_len].set(original_input_ids)
    attention_mask = jnp.ones((batch_size, total_len), dtype=jnp.int32)

    params = model.params

    @jax.jit
    def model_step(params, input_ids, attention_mask, rng):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, params=params, train=False)
        return outputs.logits

    rng = jax.random.PRNGKey(0)

    print("Generating...")
    current_len = prompt_len
    printed_len = 0

    for i in range(max_new_tokens):
        rng, step_rng = jax.random.split(rng)

        # Run the model over the full (padded) sequence; no KV cache is used here
        logits = model_step(params, input_ids, attention_mask, step_rng)

        # Get logits for the last valid token (current_len - 1)
        next_token_logits = logits[:, current_len - 1, :]

        # Temperature sampling
        scaled_logits = next_token_logits / temperature
        next_token = jax.random.categorical(step_rng, scaled_logits, axis=-1)

        # Write the sampled token into the next position
        input_ids = input_ids.at[:, current_len].set(next_token)
        current_len += 1

        # Streaming output
        valid_ids = input_ids[0, :current_len]
        current_text = tokenizer.decode(valid_ids, skip_special_tokens=False)
        new_text = current_text[printed_len:]
        if new_text:
            print(new_text, end="", flush=True)
            printed_len += len(new_text)

        # Check EOS
        if next_token[0] == tokenizer.eos_token_id:
            break

    valid_ids = input_ids[0, :current_len]
    return tokenizer.decode(valid_ids, skip_special_tokens=False)


checkpoint = "mkurman/NeuroBLAST-V3-SYNTH-EC-150000-JAX"

print(f"Loading model from {checkpoint}...")
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    use_fast=True,
    trust_remote_code=True,
)

print(f"Available devices: {jax.devices()}")
model = NeuroBLASTForCausalLMJax.from_pretrained(
    checkpoint,
    dtype=jnp.bfloat16,
    trust_remote_code=True,
    is_decoder=True,
)

generated_text = generate_text(model, tokenizer, "what is hypertension?", 128)

print("\nGenerated Text:")
print("-" * 20)
print(generated_text)
print("-" * 20)
```

You can find the **PyTorch** implementation here: [mkurman/NeuroBLAST-V3-SYNTH-EC-150000](https://huggingface.co/mkurman/NeuroBLAST-V3-SYNTH-EC-150000).
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

model_id = "mkurman/NeuroBLAST-V3-SYNTH-EC-150000"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the model (custom modeling code requires trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,
).eval()

# Stream tokens as they are generated, keeping special tokens visible
streamer = TextStreamer(
    tokenizer,
    skip_prompt=False,
    skip_special_tokens=False,
)

# Prepare input
input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "what is hypertension?"}],
    tokenize=True,
    return_tensors="pt",
    add_generation_prompt=True,
)
print(f"Input IDs: {input_ids}")

# Generate
with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids.to(model.device),
        max_new_tokens=128,
        streamer=streamer,
        use_cache=True,
        # Important: keep repetition_penalty at 1.0 for this early checkpoint
        repetition_penalty=1.0,
    )
```

## Acknowledgments

This model was trained using Cloud TPUs provided by Google's TPU Research Cloud (TRC) program.

Special thanks to [Pierre-Carl Langlais](https://huggingface.co/Pclanglais) and the [PleIAs](https://huggingface.co/PleIAs) team for the high-quality SYNTH dataset.