Commit: Cleanup binaries before space push
MMaDA/inference/common.py
CHANGED
@@ -56,6 +56,14 @@ def build_uni_prompting(cfg) -> Tuple[UniversalPrompting, AutoTokenizer]:
         cond_dropout_prob=cfg.training.cond_dropout_prob,
         use_reserved_token=True,
     )
+    # Safety: if newer task tokens are missing (e.g., <|ti2ti|>, <|t2ti|>), inject them.
+    for tok in ("<|ti2ti|>", "<|t2ti|>"):
+        if tok not in uni_prompting.sptids_dict:
+            token_id = tokenizer.convert_tokens_to_ids(tok)
+            if token_id is None or token_id == tokenizer.unk_token_id:
+                tokenizer.add_special_tokens({"additional_special_tokens": [tok]})
+                token_id = tokenizer.convert_tokens_to_ids(tok)
+            uni_prompting.sptids_dict[tok] = torch.tensor([token_id])
     return uni_prompting, tokenizer


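For reference, the injection fallback in this hunk can be exercised standalone. The sketch below is illustrative only and not part of the commit: the "gpt2" checkpoint is an arbitrary example, and the local sptids_dict merely stands in for uni_prompting.sptids_dict.

# Minimal sketch of the token-injection fallback, assuming a stock
# Hugging Face tokenizer. "gpt2" is an arbitrary example checkpoint;
# sptids_dict stands in for uni_prompting.sptids_dict from common.py.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
sptids_dict = {}

for tok in ("<|ti2ti|>", "<|t2ti|>"):
    token_id = tokenizer.convert_tokens_to_ids(tok)
    if token_id is None or token_id == tokenizer.unk_token_id:
        # Token is unknown to the vocab: register it, then re-resolve its id.
        tokenizer.add_special_tokens({"additional_special_tokens": [tok]})
        token_id = tokenizer.convert_tokens_to_ids(tok)
    sptids_dict[tok] = torch.tensor([token_id])
    print(tok, "->", int(sptids_dict[tok][0]))

One caveat: if tokens are added this way after a checkpoint was trained, the model's embedding matrix still has to be resized (model.resize_token_embeddings(len(tokenizer))) and the new rows start untrained, so this guard keeps inference from crashing but cannot supply learned embeddings for the new tokens.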
MMaDA/inference/gradio_multimodal_demo_inst.py
CHANGED
@@ -1267,9 +1267,17 @@ class OmadaDemo:
         prompt_ids = prompt_ids + [self.uni_prompting.text_tokenizer.eos_token_id]
         prompt_tensor = torch.tensor(prompt_ids, device=self.device, dtype=torch.long)

-        …
-        …
-        …
+        def _get_token(key: str):
+            tok = self.uni_prompting.sptids_dict.get(key)
+            if tok is None or tok.numel() == 0:
+                return None
+            return int(tok[0].item())
+
+        ti2ti_id = _get_token('<|ti2ti|>')
+        soi_id = _get_token('<|soi|>')
+        eoi_id = _get_token('<|eoi|>')
+        if ti2ti_id is None or soi_id is None or eoi_id is None:
+            return None, "", "TI2TI special tokens are missing in the tokenizer/config."
         pad_raw = getattr(self.uni_prompting, "pad_id", 0)
         pad_id = int(pad_raw if pad_raw is not None else 0)

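The _get_token helper added here is easy to verify in isolation. Below is a hedged sketch with a toy sptids_dict (the token ids are invented), covering both the happy path and the missing-token path that triggers the early error return.

import torch

# Toy stand-in for self.uni_prompting.sptids_dict; the ids are invented.
sptids_dict = {
    "<|soi|>": torch.tensor([50257]),
    "<|eoi|>": torch.tensor([50258]),
    # "<|ti2ti|>" is left out on purpose to exercise the failure path.
}

def _get_token(key: str):
    tok = sptids_dict.get(key)
    if tok is None or tok.numel() == 0:
        return None
    return int(tok[0].item())

assert _get_token("<|soi|>") == 50257
assert _get_token("<|eoi|>") == 50258
assert _get_token("<|ti2ti|>") is None  # would trip the early error return

Returning None here, rather than raising, lets the caller bail out with the (None, "", message) triple seen in the hunk, which presumably matches the Gradio handler's output signature, instead of surfacing a KeyError mid-generation.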