jaeikkim commited on
Commit
50db107
·
1 Parent(s): 333ef29

Cleanup binaries before space push

Browse files
MMaDA/inference/common.py CHANGED
@@ -56,6 +56,14 @@ def build_uni_prompting(cfg) -> Tuple[UniversalPrompting, AutoTokenizer]:
56
  cond_dropout_prob=cfg.training.cond_dropout_prob,
57
  use_reserved_token=True,
58
  )
 
 
 
 
 
 
 
 
59
  return uni_prompting, tokenizer
60
 
61
 
 
56
  cond_dropout_prob=cfg.training.cond_dropout_prob,
57
  use_reserved_token=True,
58
  )
59
+ # Safety: if newer task tokens are missing (e.g., <|ti2ti|>, <|t2ti|>), inject them.
60
+ for tok in ("<|ti2ti|>", "<|t2ti|>"):
61
+ if tok not in uni_prompting.sptids_dict:
62
+ token_id = tokenizer.convert_tokens_to_ids(tok)
63
+ if token_id is None or token_id == tokenizer.unk_token_id:
64
+ tokenizer.add_special_tokens({"additional_special_tokens": [tok]})
65
+ token_id = tokenizer.convert_tokens_to_ids(tok)
66
+ uni_prompting.sptids_dict[tok] = torch.tensor([token_id])
67
  return uni_prompting, tokenizer
68
 
69
 
MMaDA/inference/gradio_multimodal_demo_inst.py CHANGED
@@ -1267,9 +1267,17 @@ class OmadaDemo:
1267
  prompt_ids = prompt_ids + [self.uni_prompting.text_tokenizer.eos_token_id]
1268
  prompt_tensor = torch.tensor(prompt_ids, device=self.device, dtype=torch.long)
1269
 
1270
- ti2ti_id = int(self.uni_prompting.sptids_dict['<|ti2ti|>'][0].item())
1271
- soi_id = int(self.uni_prompting.sptids_dict['<|soi|>'][0].item())
1272
- eoi_id = int(self.uni_prompting.sptids_dict['<|eoi|>'][0].item())
 
 
 
 
 
 
 
 
1273
  pad_raw = getattr(self.uni_prompting, "pad_id", 0)
1274
  pad_id = int(pad_raw if pad_raw is not None else 0)
1275
 
 
1267
  prompt_ids = prompt_ids + [self.uni_prompting.text_tokenizer.eos_token_id]
1268
  prompt_tensor = torch.tensor(prompt_ids, device=self.device, dtype=torch.long)
1269
 
1270
+ def _get_token(key: str):
1271
+ tok = self.uni_prompting.sptids_dict.get(key)
1272
+ if tok is None or tok.numel() == 0:
1273
+ return None
1274
+ return int(tok[0].item())
1275
+
1276
+ ti2ti_id = _get_token('<|ti2ti|>')
1277
+ soi_id = _get_token('<|soi|>')
1278
+ eoi_id = _get_token('<|eoi|>')
1279
+ if ti2ti_id is None or soi_id is None or eoi_id is None:
1280
+ return None, "", "TI2TI special tokens are missing in the tokenizer/config."
1281
  pad_raw = getattr(self.uni_prompting, "pad_id", 0)
1282
  pad_id = int(pad_raw if pad_raw is not None else 0)
1283