abhinavv3 committed
Commit 0cac660 · 1 Parent(s): f58ea49

Fixed some issues and bugs. Finished trial training successfully.

Readme.md CHANGED
Binary files a/Readme.md and b/Readme.md differ
 
configs/config.json CHANGED
@@ -5,13 +5,14 @@
     "n_layer": 12,
     "n_head": 12,
     "n_embd": 768,
-    "n_kv_head": 4
+    "n_kv_head": 4,
+    "max_knn_memories": 81920
   },
   "training": {
     "max_steps": 19073,
     "log_dir": "log",
     "total_batch_size": 524288,
-    "B": 8,
+    "B": 64,
     "T": 1024,
     "max_lr": 0.0006,
     "min_lr": 0.00006,
log/log.txt CHANGED
@@ -0,0 +1,3 @@
+0 val 10.9481 shard_0
+0 train 10.947327 shard_0
+1 train 10.917969 shard_0
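
As a quick sanity check on these first values (my note, not from the repo): a freshly initialized model over the padded 50,304-token GPT-2 vocabulary should start near the uniform-prediction cross-entropy, which is in the same ballpark as the ~10.95 logged at step 0:

import math
# Cross-entropy of a uniform prediction over a 50304-token vocabulary.
print(math.log(50304))  # ~10.83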
model_core/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (162 Bytes).
 
model_core/__pycache__/attention.cpython-311.pyc CHANGED
Binary files a/model_core/__pycache__/attention.cpython-311.pyc and b/model_core/__pycache__/attention.cpython-311.pyc differ
 
model_core/__pycache__/model.cpython-311.pyc CHANGED
Binary files a/model_core/__pycache__/model.cpython-311.pyc and b/model_core/__pycache__/model.cpython-311.pyc differ
 
model_core/__pycache__/training.cpython-311.pyc CHANGED
Binary files a/model_core/__pycache__/training.cpython-311.pyc and b/model_core/__pycache__/training.cpython-311.pyc differ
 
model_core/__pycache__/training.cpython-312.pyc ADDED
Binary file (11.7 kB).
 
model_core/attention.py CHANGED
@@ -9,7 +9,7 @@ import torch._dynamo
 class RotaryPositionalEncoding(nn.Module):
     def __init__(self, dim, max_seq_len=2048, base=10000):
         super().__init__()
-        assert dim % 2 == 0
+        assert dim % 2 == 0, f"Dimension {dim} must be even for RoPE"
 
         self.dim = dim
         self.max_seq_len = max_seq_len
@@ -31,36 +31,27 @@ class RotaryPositionalEncoding(nn.Module):
         self._cached_seq_len = seq_len
         return self._cached_freqs[0][:seq_len], self._cached_freqs[1][:seq_len]
 
-    def apply_rotary_pos_emb(self, q, k):
-        q_len = q.shape[2]
-        k_len = k.shape[2]
-        assert q.shape[-1] == self.dim, f"Expected q.shape[-1] == {self.dim}, got {q.shape[-1]}"
-        assert k.shape[-1] == self.dim, f"Expected k.shape[-1] == {self.dim}, got {k.shape[-1]}"
-        assert q_len <= self.max_seq_len, f"seq_len {q_len} exceeds max_seq_len {self.max_seq_len}"
-        assert k_len <= self.max_seq_len, f"seq_len {k_len} exceeds max_seq_len {self.max_seq_len}"
-
-        device = q.device
-        cos_q, sin_q = self._get_freqs(q_len, device) # both [seq_len, dim//2]
-        cos_k, sin_k = self._get_freqs(k_len, device) # both [seq_len, dim//2]
-
+    def apply_rotary_pos_emb(self, x, seq_len=None):
+        if seq_len is None:
+            seq_len = x.shape[2]
+
+        assert x.shape[-1] == self.dim, f"Expected x.shape[-1] == {self.dim}, got {x.shape[-1]}"
+        assert seq_len <= self.max_seq_len, f"seq_len {seq_len} exceeds max_seq_len {self.max_seq_len}"
 
-        # Expand to match q/k: [1, 1, seq_len, dim//2]
-        cos_q = cos_q[None, None, :, :].expand(q.shape[0], q.shape[1], -1, -1)
-        sin_q = sin_q[None, None, :, :].expand(q.shape[0], q.shape[1], -1, -1)
-        cos_k = cos_k[None, None, :, :].expand(q.shape[0], q.shape[1], -1, -1)
-        sin_k = sin_k[None, None, :, :].expand(q.shape[0], q.shape[1], -1, -1)
+        device = x.device
+        cos, sin = self._get_freqs(seq_len, device) # both [seq_len, dim//2]
 
-        def apply(x, cos, sin):
-            x1 = x[..., ::2]
-            x2 = x[..., 1::2]
-
-            x_rotated_even = x1 * cos - x2 * sin
-            x_rotated_odd = x1 * sin + x2 * cos
-            return torch.stack((x_rotated_even, x_rotated_odd), dim=-1).flatten(-2)
+        # Expand to match x: [1, 1, seq_len, dim//2]
+        cos = cos[None, None, :, :].expand(x.shape[0], x.shape[1], -1, -1)
+        sin = sin[None, None, :, :].expand(x.shape[0], x.shape[1], -1, -1)
+
+        x1 = x[..., ::2] # even indices
+        x2 = x[..., 1::2] # odd indices
 
-        q_rot = apply(q, cos_q, sin_q)
-        k_rot = apply(k, cos_k, sin_k)
-        return q_rot, k_rot
+        x_rotated_even = x1 * cos - x2 * sin
+        x_rotated_odd = x1 * sin + x2 * cos
+
+        return torch.stack((x_rotated_even, x_rotated_odd), dim=-1).flatten(-2)
 
 class KNN():
     def __init__(self, dim, max_memories, process_rank=0):
@@ -150,18 +141,19 @@ class XLAttention(nn.Module):
         self.n_kv_head = getattr(config, 'n_kv_head', config.n_head)
         self.n_embd = config.n_embd
         self.head_dim = config.n_embd // config.n_head
-        self.kv_head_dim = config.n_embd // self.n_kv_head
+        self.kv_head_dim = self.head_dim
         self.group_size = self.n_head // self.n_kv_head
         self.dropout = nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.0)
         self.scale = self.head_dim ** -0.5
-
+
         self.q_proj = nn.Linear(config.n_embd, config.n_embd)
         self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.kv_head_dim)
         self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.kv_head_dim)
         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
         self.c_proj.MEMGPT_SCALE_INIT = 1
 
-        self.rope = RotaryPositionalEncoding(self.head_dim)
+        self.rope_q = RotaryPositionalEncoding(self.head_dim)
+        self.rope_k = RotaryPositionalEncoding(self.kv_head_dim)
 
     def forward(self, x, xl_memory=None):
         B, T, C = x.size()
@@ -184,7 +176,8 @@ class XLAttention(nn.Module):
 
         # Apply rotary positional encoding
         seq_len = k.shape[2]
-        q, k = self.rope.apply_rotary_pos_emb(q, k)
+        q = self.rope_q.apply_rotary_pos_emb(q)
+        k = self.rope_k.apply_rotary_pos_emb(k)
 
         k = k.repeat_interleave(self.group_size, dim=1) # (B, n_head, T+xl, kv_head_dim)
         v = v.repeat_interleave(self.group_size, dim=1) # (B, n_head, T+xl, kv_head_dim)
@@ -226,7 +219,7 @@ class KNNAttention(nn.Module):
         self.n_kv_head = getattr(config, 'n_kv_head', config.n_head)
         self.n_embd = config.n_embd
         self.head_dim = config.n_embd // config.n_head
-        self.kv_head_dim = config.n_embd // self.n_kv_head
+        self.kv_head_dim = self.head_dim
         self.group_size = self.n_head // self.n_kv_head
         self.dropout = nn.Dropout(config.dropout if hasattr(config, 'dropout') else 0.0)
         self.scale = self.head_dim ** -0.5
@@ -241,7 +234,8 @@ class KNNAttention(nn.Module):
         self.topk_retrieved_memories = topk_retrieved_memories
         self.knn = knn
 
-        self.rope = RotaryPositionalEncoding(self.head_dim)
+        self.rope_q = RotaryPositionalEncoding(self.head_dim)
+        self.rope_k = RotaryPositionalEncoding(self.kv_head_dim)
 
     def forward(self, x, xl_memory=None):
         B, T, C = x.size()
@@ -265,7 +259,8 @@ class KNNAttention(nn.Module):
         v = v.view(B, -1, self.n_kv_head, self.kv_head_dim).transpose(1, 2) # (B, n_kv_head, seq_len, kv_head_dim) # GQAchange
 
         seq_len = k.shape[2]
-        q, k = self.rope.apply_rotary_pos_emb(q, k)
+        q = self.rope_q.apply_rotary_pos_emb(q)
+        k = self.rope_k.apply_rotary_pos_emb(k)
 
         k_expanded = k.repeat_interleave(self.group_size, dim=1) # (B, n_head, seq_len, kv_head_dim)
         v_expanded = v.repeat_interleave(self.group_size, dim=1) # (B, n_head, seq_len, kv_head_dim)
@@ -279,9 +274,12 @@ class KNNAttention(nn.Module):
         local_out = att @ v_expanded
 
         # KNN ATTENTION
+        # Making some modifications to the query shape for searching in the db, which is different from the original paper.
         if self.knn.index.ntotal > 0:
-            q_search = q.transpose(1, 2).contiguous().view(B, T, C)
-            mem_kv = self.knn.search(q_search, topk=self.topk_retrieved_memories)
+            q_grouped = q.view(B, self.n_kv_head, self.group_size, T, self.head_dim) # (B, n_head, T, head_dim) -> (B, n_kv_head, group_size, T, head_dim)
+            q_mean = q_grouped.mean(dim=2) # (B, 4, T, 64), took average across the 3 heads in each group
+            q_knn = q_mean.transpose(1, 2).contiguous().view(B, T, -1) # (B, T, 256)
+            mem_kv = self.knn.search(q_knn, topk=self.topk_retrieved_memories)
             mem_k, mem_v = mem_kv.unbind(dim=-2)
 
             # Reshape memory K,V according to KV head structure
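
The interesting change above is the query used for the KNN lookup: with grouped-query attention the memory index now stores keys/values of width n_kv_head * head_dim, so the per-head queries are averaged within each KV group before searching. A standalone sketch of that shape manipulation, using the n_head = 12, n_kv_head = 4, n_embd = 768 values from configs/config.json (illustrative only, not the module itself):

import torch

B, T = 2, 8                                   # toy batch and sequence length
n_head, n_kv_head, head_dim = 12, 4, 64       # head_dim = 768 // 12
group_size = n_head // n_kv_head              # 3 query heads per KV head

q = torch.randn(B, n_head, T, head_dim)
q_grouped = q.view(B, n_kv_head, group_size, T, head_dim)   # group query heads by KV head
q_mean = q_grouped.mean(dim=2)                              # average the 3 heads in each group
q_knn = q_mean.transpose(1, 2).contiguous().view(B, T, -1)  # (B, T, n_kv_head * head_dim)
print(q_knn.shape)                                          # torch.Size([2, 8, 256])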
model_core/model.py CHANGED
@@ -52,9 +52,10 @@ class GPT(nn.Module):
         super().__init__()
         self.config = config
         self.process_rank = process_rank
+        kv_dim = config.n_kv_head * (config.n_embd // config.n_head)
 
         # Initialize KNN memory
-        self.knn = KNN(config.n_embd, config.max_knn_memories, process_rank)
+        self.knn = KNN(kv_dim, config.max_knn_memories, process_rank)
 
         self.transformer = nn.ModuleDict(dict(
             wte=nn.Embedding(config.vocab_size, config.n_embd),
@@ -120,6 +121,13 @@
         return logits, loss
 
     def configure_optimizers(self, weight_decay, learning_rate, device_type, master_process):
+        # Print model parameters
+        total_params = sum(p.numel() for p in self.parameters())
+        print(f"Total parameters: {total_params}")
+        # Trainable parameters
+        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+        print(f"Trainable parameters: {trainable_params}")
+
         # Get all parameters that require grad
         param_dict = {pn: p for pn, p in self.named_parameters()}
         param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
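
This keeps the KNN memory dimension in sync with the grouped KV projections: with the config values above, the index width drops from n_embd = 768 to 256, matching the (B, T, 256) query built in KNNAttention. A quick check of the arithmetic (my note, not code from the repo):

n_embd, n_head, n_kv_head = 768, 12, 4
head_dim = n_embd // n_head          # 64
kv_dim = n_kv_head * head_dim        # 256 (previously the index used n_embd = 768)
print(head_dim, kv_dim)              # 64 256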
model_core/training.py CHANGED
@@ -103,7 +103,6 @@ def train_memgpt(config_path,dataloader_class=None):
     for step in range(max_steps):
         t0 = time.time()
         last_step = (step == max_steps - 1)
-        print(f"validation loop.step={step}")
         if step % 350 == 0 or last_step:
             model.eval()
             val_loader.reset()
@@ -168,7 +167,6 @@
         loss_accum = 0.0
 
         for micro_step in range(grad_accum_steps):
-            print(f"micro tep= {micro_step}")
             x, y, current_shard_num = train_loader.next_batch()
             x, y = x.to(device), y.to(device)
 
requirements.txt2 DELETED
Binary file (2.52 kB)
 
rough_work.py DELETED
File without changes
scripts/generate.py CHANGED
@@ -1,63 +1,2 @@
-import torch
-import torch.nn.functional as F
-import tiktoken
-from model import GPT
 
-def generate_text(model, prompt, num_return_sequences=4, max_length=32, device='cuda'):
-    model.eval()
-    enc = tiktoken.get_encoding('gpt2')
-    tokens = enc.encode(prompt)
-    tokens = torch.tensor(tokens, dtype=torch.long)
-    tokens = tokens.unsqueeze(0).repeat(num_return_sequences, 1)
-    xgen = tokens.to(device)
-    sample_rng = torch.Generator(device=device)
-    sample_rng.manual_seed(42)
-
-    while xgen.size(1) < max_length:
-        with torch.no_grad():
-            logits, loss = model(xgen) # (B, T, vocab_size)
-            logits = logits[:, -1, :] # (B, vocab_size)
-            probs = F.softmax(logits, dim=-1)
-            topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
-            ix = torch.multinomial(topk_probs, 1, generator=sample_rng)
-            xcol = torch.gather(topk_indices, -1, ix)
-            xgen = torch.cat((xgen, xcol), dim=1)
-
-    generated_texts = []
-    for i in range(num_return_sequences):
-        tokens = xgen[i, :max_length].tolist()
-        decoded = enc.decode(tokens)
-        generated_texts.append(decoded)
-        print(f"Sample {i + 1}: {decoded}")
-
-
-    return generated_texts
-
-
-device = 'cuda' if torch.cuda.is_available() else 'cpu'
-print(f"running with {device}")
-
-
-checkpoint_path = 'log/model_final.pt'
-
-print(f"Loading checkpoint from {checkpoint_path}")
-checkpoint = torch.load(checkpoint_path,map_location=device)
-model_config = checkpoint['config']
-model_config.vocab_size = 50304
-model = GPT(model_config)
-
-
-model.load_state_dict(checkpoint['model'])
-model.to(device)
-
-
-
-prompt = "Hello, I'm a language model,"
-
-generated_texts = generate_text(
-    model=model,
-    prompt=prompt,
-    num_return_sequences=4,
-    max_length=32,
-    device=device
-)
+# Inference part not completed