---
license: apache-2.0
language:
- en
datasets:
- PleIAs/SYNTH
library_name: transformers
tags:
- pytorch
- causal-lm
- neuroblast
- experimental
- tpu
---
# NeuroBLAST-V3-SYNTH-EC-150000-JAX
⚠️ EXPERIMENTAL EARLY JAX CHECKPOINT ⚠️
This is an **Early JAX Checkpoint (EC)** of the **NeuroBLAST V3** architecture, a novel hybrid model designed with a biologically inspired "cortical" structure.
This specific checkpoint (`150k` steps) represents the "pre-decay" phase of training. It has been trained on short contexts with a high learning rate and is intended for architectural evaluation and research purposes.
## Model Details
* **Architecture:** NeuroBLAST V3 (Custom Hybrid Architecture)
* **Checkpoint Step:** 150,000
* **Parameters:** 596,728,320
* **Num layers:** 72
* **Sensory layers:** 24
* **Associative layers:** 32
* **Motor layers:** 16
* **Hidden size:** 512
* **Vocab size:** 65538
* **Intermediate size:** 3072
* **Num attention heads:** 16
* **Num KV heads:** 8
* **Head dim:** 128
* **Tie word embeddings:** False
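To double-check these values against the shipped configuration, you can load it with `AutoConfig` as below. The attribute names are defined by the checkpoint's custom config class, so the snippet simply prints the whole object rather than assuming specific field names.
```python
from transformers import AutoConfig

# The config class is custom, so trust_remote_code=True is required.
config = AutoConfig.from_pretrained(
    "mkurman/NeuroBLAST-V3-SYNTH-EC-150000-JAX",
    trust_remote_code=True,
)

# Prints every stored field (hidden size, layer counts, head counts, ...).
print(config)
```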
### Architecture Highlights
NeuroBLAST differs from standard Transformers by using a three-stage cortical design connected by deep residual bridges, sketched schematically after this list:
1. **Sensory Cortex:** Hybrid layers alternating between Attention and Dilated Causal 2D Convolutions.
2. **Associative Cortex:** Hybrid layers with alternating RoPE usage.
3. **Motor Cortex:** Pure Attention layers.
4. **Deep Residual Bridges:** Long-range residual connections injecting the original embeddings (and their negations) between cortical stages to improve signal propagation.
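The snippet below is only a schematic reading of the description above, not the code in `modeling_neuroblast.py`: the exact layer implementations, the placement of the bridges, and the sign convention for the negated embeddings are assumptions.
```python
def neuroblast_forward(embeddings, sensory_layers, associative_layers, motor_layers):
    """Schematic three-stage forward pass with deep residual bridges (illustrative only)."""
    x = embeddings

    # 1. Sensory cortex: attention alternating with dilated causal convolutions.
    for layer in sensory_layers:
        x = layer(x)

    # Deep residual bridge: re-inject the original embeddings between stages.
    x = x + embeddings

    # 2. Associative cortex: hybrid layers with alternating RoPE usage.
    for layer in associative_layers:
        x = layer(x)

    # Deep residual bridge: re-inject the negated embeddings (sign placement assumed).
    x = x - embeddings

    # 3. Motor cortex: pure attention layers.
    for layer in motor_layers:
        x = layer(x)

    return x
```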

## Training Details
This model is currently being trained on Cloud TPUs provided by the Google TPU Research Cloud (TRC). An illustrative optax sketch of these settings follows the list.
* **Dataset:** [PleIAs/SYNTH](https://huggingface.co/datasets/PleIAs/SYNTH)
* **Tokens Processed:** ~118 Billion
* **Hardware:** TPUv4-16
* **Training Time:** ~8 Days
* **Effective Batch Size:** 1024
* **Context Length:** 768 tokens (Current phase)
* **Learning rate:** 4e-3
* **Weight decay:** 0.0
* **Optimizer:** AdamW
* **Precision:** BFloat16
* **Current State:** Pre-decay phase (constant high learning rate; no weight decay applied yet).
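The training script itself is not published here; the optax snippet below is only an illustrative reconstruction of the pre-decay settings listed above (constant learning rate, no weight decay), not the authors' code.
```python
import optax

# Constant learning rate for the pre-decay phase.
learning_rate = 4e-3

# AdamW with weight decay disabled, matching the settings above.
optimizer = optax.adamw(learning_rate=learning_rate, weight_decay=0.0)

# Training runs in bfloat16 on a TPUv4-16 slice with an effective batch size of
# 1024 sequences of 768 tokens; how the batch is accumulated/sharded is not specified here.
```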

### Roadmap
This checkpoint marks the end of the initial warmup/learning phase. The next steps in training, sketched hypothetically after this list, are:
1. Significantly extending the context length.
2. Lowering the learning rate.
3. Introducing weight decay for convergence.
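As a purely hypothetical illustration of steps 2 and 3, the decay phase could look like the optax setup below. The decay horizon, learning-rate floor, and weight-decay strength are placeholders, not published settings.
```python
import optax

# Hypothetical decay-phase schedule: cosine decay from the pre-decay learning rate.
decay_schedule = optax.cosine_decay_schedule(
    init_value=4e-3,      # current (pre-decay) learning rate
    decay_steps=100_000,  # placeholder decay horizon
    alpha=0.1,            # placeholder floor: final LR = alpha * init_value
)

# Weight decay switched on for convergence; 0.1 is a placeholder value.
optimizer = optax.adamw(learning_rate=decay_schedule, weight_decay=0.1)
```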
## Usage
**Note:** You must pass `trust_remote_code=True`, as this model relies on custom modeling code (`modeling_neuroblast.py`).
```python
import jax
import jax.numpy as jnp
from transformers import AutoTokenizer

from neuroblast3_jax.modeling_neuroblast_jax import NeuroBLASTForCausalLM as NeuroBLASTForCausalLMJax


def generate_text(model, tokenizer, text, max_new_tokens=50, temperature=0.7, top_k=50):
    inputs = tokenizer(f"user\n{text}<|im_end|><|im_start|>assistant\n", return_tensors="np")
    original_input_ids = inputs["input_ids"]
    batch_size, prompt_len = original_input_ids.shape
    total_len = prompt_len + max_new_tokens

    # Pre-allocate a fixed-size buffer so the jitted forward pass is never retraced.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    input_ids = jnp.full((batch_size, total_len), pad_id, dtype=jnp.int32)
    input_ids = input_ids.at[:, :prompt_len].set(original_input_ids)
    attention_mask = jnp.ones((batch_size, total_len), dtype=jnp.int32)

    params = model.params

    @jax.jit
    def model_step(params, input_ids, attention_mask):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, params=params, train=False)
        return outputs.logits

    rng = jax.random.PRNGKey(0)
    print("Generating...")

    current_len = prompt_len
    printed_len = 0
    for _ in range(max_new_tokens):
        rng, step_rng = jax.random.split(rng)

        # Run the model on the full (padded) buffer.
        logits = model_step(params, input_ids, attention_mask)

        # Logits for the last valid token (position current_len - 1).
        next_token_logits = logits[:, current_len - 1, :]

        # Temperature scaling followed by top-k filtering and sampling.
        scaled_logits = next_token_logits / temperature
        if top_k is not None:
            top_logits, _ = jax.lax.top_k(scaled_logits, top_k)
            cutoff = top_logits[:, -1:]  # k-th largest logit per batch element
            scaled_logits = jnp.where(scaled_logits < cutoff, -jnp.inf, scaled_logits)
        next_token = jax.random.categorical(step_rng, scaled_logits, axis=-1)

        # Write the new token into the next free position of the buffer.
        input_ids = input_ids.at[:, current_len].set(next_token)
        current_len += 1

        # Streaming output: decode everything so far and print only the new suffix.
        valid_ids = input_ids[0, :current_len]
        current_text = tokenizer.decode(valid_ids, skip_special_tokens=False)
        new_text = current_text[printed_len:]
        if new_text:
            print(new_text, end="", flush=True)
            printed_len += len(new_text)

        # Stop on end-of-sequence.
        if next_token[0] == tokenizer.eos_token_id:
            break

    valid_ids = input_ids[0, :current_len]
    return tokenizer.decode(valid_ids, skip_special_tokens=False)


checkpoint = "mkurman/NeuroBLAST-V3-SYNTH-EC-150000-JAX"

print(f"Loading model from {checkpoint}...")
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    use_fast=True,
    trust_remote_code=True,
)

print(f"Available devices: {jax.devices()}")
model = NeuroBLASTForCausalLMJax.from_pretrained(
    checkpoint,
    dtype=jnp.bfloat16,
    trust_remote_code=True,
    is_decoder=True,
)

generated_text = generate_text(model, tokenizer, 'what is hypertension?', 128)

print("\nGenerated Text:")
print("-" * 20)
print(generated_text)
print("-" * 20)
```
A **PyTorch** implementation is available at [mkurman/NeuroBLAST-V3-SYNTH-EC-150000](https://huggingface.co/mkurman/NeuroBLAST-V3-SYNTH-EC-150000):
```python
import torch
from transformers import AutoTokenizer, TextStreamer, AutoModelForCausalLM

model_id = "mkurman/NeuroBLAST-V3-SYNTH-EC-150000"

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the model (custom modeling code requires trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map='cuda',
    trust_remote_code=True
).eval()

# Extra keyword arguments are forwarded to tokenizer.decode()
streamer = TextStreamer(
    tokenizer, skip_prompt=False, skip_special_tokens=False
)

# Prepare input
input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "what is hypertension?"}],
    tokenize=True,
    return_tensors="pt",
    add_generation_prompt=True
)
print(f"Input IDs: {input_ids}")

# Generate
with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids.to(model.device),
        max_new_tokens=128,
        streamer=streamer,
        use_cache=True,
        # Important: keep repetition_penalty at 1.0 for this early checkpoint
        repetition_penalty=1.0,
    )
```
## Acknowledgments
This model was trained using Cloud TPUs provided by Google's TPU Research Cloud (TRC) program.
Special thanks to [Pierre-Carl Langlais](https://huggingface.co/Pclanglais) and the [PleIAs](https://huggingface.co/PleIAs) team for the high-quality SYNTH dataset.