Cleanup binaries before space push

- MMaDA/inference/common.py +1 -0
- MMaDA/models/modeling_omada.py +195 -0
MMaDA/inference/common.py
@@ -51,6 +51,7 @@ def build_uni_prompting(cfg) -> Tuple[UniversalPrompting, AutoTokenizer]:
             "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>",
             "<|i2i|>", "<|v2s|>", "<|s2s|>",
             "<|v2t|>", "<|s2t|>", "<|t2s|>", "<|soa|>", "<|eoa|>",
+            "<|ti2ti|>", "<|t2ti|>",
         ),
         ignore_id=-100,
         cond_dropout_prob=cfg.training.cond_dropout_prob,
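
The new "<|ti2ti|>" and "<|t2ti|>" entries register task tokens for the text+image -> text+image and text -> text+image settings used by ti2ti_generate below. A minimal lookup sketch, assuming the tokens end up in uni_prompting.sptids_dict like the other special tokens (the same pattern the model code uses for <|soi|>/<|eoi|>):

    # Hypothetical usage: sptids_dict maps a special-token string to a one-element id tensor.
    ti2ti_id = int(uni_prompting.sptids_dict["<|ti2ti|>"][0].item())
    t2ti_id = int(uni_prompting.sptids_dict["<|t2ti|>"][0].item())
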
MMaDA/models/modeling_omada.py
@@ -1802,6 +1802,201 @@ class OMadaModelLM(LLaDAModelLM):
 
         return x
 
+    @torch.no_grad()
+    def ti2ti_generate(
+        self,
+        input_ids: torch.LongTensor = None,
+        uncond_input_ids: torch.LongTensor = None,
+        attention_mask=None,
+        uncond_attention_mask=None,
+        temperature=1.0,
+        timesteps=18,
+        timesteps_text: int | None = None,
+        timesteps_image: int | None = None,
+        guidance_scale=0,
+        noise_schedule=cosine_schedule,
+        generator: torch.Generator = None,
+        config=None,
+        seq_len=1024,
+        mask_token_id=126336,
+        resolution=512,
+        codebook_size=8192,
+        uni_prompting=None,
+        **kwargs,
+    ):
+        """
+        TI2TI generation that fills masked text and image tokens; allows separate timesteps.
+        Returns (filled_tokens, decoded_texts).
+        """
+        if input_ids is None or attention_mask is None:
+            raise ValueError("input_ids and attention_mask are required for ti2ti_generate.")
+        if uni_prompting is None:
+            raise ValueError("uni_prompting is required for ti2ti_generate.")
+
+        device = input_ids.device
+        text_vocab_size = len(uni_prompting.text_tokenizer)
+        image_vocab_start = text_vocab_size
+        image_vocab_end = image_vocab_start + codebook_size
+        timesteps_text = timesteps if timesteps_text is None else timesteps_text
+        timesteps_image = timesteps if timesteps_image is None else timesteps_image
+
+        seq = input_ids.clone()
+        if attention_mask is None:
+            attn = torch.ones_like(seq, dtype=torch.long)
+        else:
+            attn = attention_mask
+        use_guidance = uncond_input_ids is not None and guidance_scale > 0
+        if use_guidance:
+            seq_uncond = uncond_input_ids.clone()
+            if uncond_attention_mask is None:
+                attn_uncond = torch.ones_like(seq_uncond, dtype=torch.long)
+            else:
+                attn_uncond = uncond_attention_mask
+        else:
+            seq_uncond = None
+            attn_uncond = None
+        total_len = seq.shape[1]
+
+        def _uniform_transfer_plan(mask_bool: torch.Tensor, steps_count: int) -> Optional[torch.Tensor]:
+            """Evenly divide masked token updates across steps."""
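+            # e.g. 10 masked tokens over 4 steps -> per-step plan [3, 3, 2, 2]:
+            # base 10 // 4 = 2 per step, remainder 10 % 4 = 2 added to the first two steps.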
+            if steps_count is None or steps_count <= 0:
+                return None
+            mask_num = mask_bool.sum(dim=1, keepdim=True)
+            if mask_num.numel() == 0:
+                return None
+            base = mask_num // steps_count
+            remainder = mask_num % steps_count
+            plan = torch.zeros(mask_num.size(0), steps_count, device=mask_bool.device, dtype=torch.int64) + base
+            for idx in range(mask_num.size(0)):
+                rem_val = remainder[idx].item()
+                if rem_val > 0:
+                    plan[idx, :rem_val] += 1
+            return plan
+
+        prompt_block_len = uni_prompting.max_text_len
+        soi_id = int(uni_prompting.sptids_dict.get("<|soi|>", torch.tensor([-1]))[0].item())
+        eoi_id = int(uni_prompting.sptids_dict.get("<|eoi|>", torch.tensor([-1]))[0].item())
+        pad_id = int(getattr(uni_prompting, "pad_id", 0))
+
+        def _locate_blocks(sample_seq: torch.Tensor, sample_attn: Optional[torch.Tensor]):
+            # Find second (target) soi/eoi pair; fallback to template formula.
+            soi_positions = (sample_seq == soi_id).nonzero(as_tuple=True)[0]
+            eoi_positions = (sample_seq == eoi_id).nonzero(as_tuple=True)[0]
+            tgt_soi = None
+            tgt_eoi = None
+            if soi_positions.numel() >= 2:
+                tgt_soi = int(soi_positions[1].item())
+                tgt_eoi_candidates = [int(e.item()) for e in eoi_positions if int(e.item()) > tgt_soi]
+                if tgt_eoi_candidates:
+                    tgt_eoi = tgt_eoi_candidates[0]
+
+            if tgt_soi is None or tgt_eoi is None:
+                # fallback: compute with pad offset the old way
+                non_pad = (sample_seq != pad_id).nonzero(as_tuple=True)
+                pad_offset = int(non_pad[0][0].item()) if len(non_pad) > 0 and non_pad[0].numel() > 0 else 0
+                tgt_soi = pad_offset + 1 + 1 + seq_len + 1 + prompt_block_len + 1  # soi before target img
+                tgt_eoi = tgt_soi + seq_len + 1  # eoi after target img
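+                # Template components assumed by the offsets: task token, <|soi|>,
+                # source image (seq_len tokens), <|eoi|>, source text (prompt_block_len
+                # tokens), then the target <|soi|>/<|eoi|> image block and target text.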
+
+            img_start_local = tgt_soi + 1
+            img_end_local = min(tgt_eoi, sample_seq.size(0))
+
+            if sample_attn is not None:
+                text_attn = sample_attn[tgt_eoi + 1 :]
+                nonzero = (text_attn != 0).nonzero(as_tuple=True)
+                if len(nonzero) > 0 and nonzero[0].numel() > 0:
+                    last_idx = int(nonzero[0][-1].item())
+                    text_end_local = tgt_eoi + 1 + last_idx + 1
+                else:
+                    text_end_local = tgt_eoi + 1 + prompt_block_len
+            else:
+                text_end_local = tgt_eoi + 1 + prompt_block_len
+            text_start_local = tgt_eoi + 1
+            text_end_local = min(text_end_local, sample_seq.size(0))
+            return img_start_local, img_end_local, text_start_local, text_end_local
+
+        img_start, img_end, text_start, text_end = _locate_blocks(seq[0], attn[0] if attn is not None else None)
+        text_indices = torch.arange(total_len, device=device)
+        initial_text_mask = (seq == mask_token_id) & (text_indices >= text_start) & (text_indices < text_end)
+        text_transfer_plan = _uniform_transfer_plan(initial_text_mask, timesteps_text)
+        text_step_idx = 0
+
+        # Simultaneous fill: at each step, update image/text masks that still remain
+        max_steps = max(timesteps_image, timesteps_text)
+        for step in range(max_steps):
+            mask_map = seq == mask_token_id
+            img_mask = mask_map & (text_indices >= img_start) & (text_indices < img_end) if step < timesteps_image else None
+            text_mask = mask_map & (text_indices >= text_start) & (text_indices < text_end) if step < timesteps_text else None
+            if not ((img_mask is not None and img_mask.any()) or (text_mask is not None and text_mask.any())):
+                break
+
+            attn_bias = (attn[:, :, None] & attn[:, None, :]).bool().unsqueeze(1)
+            logits_cond = self(seq, attention_bias=attn_bias).logits
+            if use_guidance:
+                attn_bias_uncond = (attn_uncond[:, :, None] & attn_uncond[:, None, :]).bool().unsqueeze(1)
+                logits_uncond = self(seq_uncond, attention_bias=attn_bias_uncond).logits
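+                # Classifier-free guidance in the (w + 1) parameterization:
+                # logits = cond + guidance_scale * (cond - uncond), so a
+                # guidance_scale of 0 reproduces the conditional logits.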
+
logits = logits_uncond + (guidance_scale + 1.0) * (logits_cond - logits_uncond)
|
| 1938 |
+
else:
|
| 1939 |
+
logits = logits_cond
|
| 1940 |
+
|
| 1941 |
+
if text_mask is not None and text_mask.any():
|
| 1942 |
+
logits_text = logits[..., :text_vocab_size]
|
| 1943 |
+
probs_text = logits_text.softmax(dim=-1)
|
| 1944 |
+
sampled_text = torch.multinomial(
|
| 1945 |
+
probs_text.view(-1, text_vocab_size),
|
| 1946 |
+
1,
|
| 1947 |
+
replacement=False
|
| 1948 |
+
).view(*logits_text.shape[:2])
|
| 1949 |
+
sampled_probs = torch.gather(
|
| 1950 |
+
probs_text, dim=-1, index=sampled_text.unsqueeze(-1)
|
| 1951 |
+
).squeeze(-1)
|
| 1952 |
+
candidate_seq = torch.where(text_mask, sampled_text, seq)
|
| 1953 |
+
confidence = torch.full_like(sampled_probs, float("-inf"))
|
| 1954 |
+
confidence = torch.where(text_mask, sampled_probs, confidence)
|
| 1955 |
+
if text_transfer_plan is not None and text_step_idx < text_transfer_plan.shape[1]:
|
| 1956 |
+
transfer_counts = text_transfer_plan[:, text_step_idx]
|
| 1957 |
+
else:
|
| 1958 |
+
transfer_counts = text_mask.sum(dim=1)
|
| 1959 |
+
transfer_mask = torch.zeros_like(text_mask, dtype=torch.bool)
|
| 1960 |
+
for b_idx in range(seq.shape[0]):
|
| 1961 |
+
mask_count = int(text_mask[b_idx].sum().item())
|
| 1962 |
+
if mask_count == 0:
|
| 1963 |
+
continue
|
| 1964 |
+
k = int(min(max(transfer_counts[b_idx].item(), 0), mask_count))
|
| 1965 |
+
if k <= 0:
|
| 1966 |
+
continue
|
| 1967 |
+
_, top_idx = torch.topk(confidence[b_idx], k=k)
|
| 1968 |
+
transfer_mask[b_idx, top_idx] = True
|
| 1969 |
+
if transfer_mask.any():
|
| 1970 |
+
seq = torch.where(transfer_mask, candidate_seq, seq)
|
| 1971 |
+
text_step_idx += 1
|
| 1972 |
+
|
| 1973 |
+
if img_mask is not None and img_mask.any():
|
| 1974 |
+
logits_img = logits[..., image_vocab_start:image_vocab_end]
|
| 1975 |
+
probs_img = logits_img.softmax(dim=-1)
|
| 1976 |
+
sampled_img = torch.multinomial(
|
| 1977 |
+
probs_img.view(-1, codebook_size),
|
| 1978 |
+
1,
|
| 1979 |
+
replacement=False
|
| 1980 |
+
).view(*logits_img.shape[:2]) + image_vocab_start
|
| 1981 |
+
seq = torch.where(img_mask, sampled_img, seq)
|
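+                # Note: every still-masked image position is committed in this single
+                # pass; unlike the text block, there is no per-step confidence budget here.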
+
+            if use_guidance:
+                updated_mask = torch.zeros_like(seq, dtype=torch.bool)
+                if img_mask is not None:
+                    updated_mask |= img_mask
+                if text_mask is not None:
+                    updated_mask |= text_mask
+                seq_uncond = torch.where(updated_mask, seq, seq_uncond)
+
+        # Decode text tokens from filled sequence
+        pred_texts = []
+        for row in seq:
+            text_tokens = [int(t) for t in row.tolist() if 0 <= t < text_vocab_size]
+            pred_texts.append(uni_prompting.text_tokenizer.decode(text_tokens, skip_special_tokens=True))
+
+        return seq, pred_texts
+
+
     @torch.no_grad()
     def mmu_generate_fast(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None):
         """
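
A minimal invocation sketch, assuming `model` is an OMadaModelLM, `uni_prompting` is the UniversalPrompting instance from build_uni_prompting, and `input_ids`/`attention_mask` already encode the TI2TI template with mask_token_id placeholders in the target image and text blocks; variable names and the guidance value are illustrative, not from this commit:

    # Fill the masked target blocks; returns the completed sequence and decoded texts.
    tokens, texts = model.ti2ti_generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        uncond_input_ids=uncond_input_ids,            # optional: enables classifier-free guidance
        uncond_attention_mask=uncond_attention_mask,
        guidance_scale=3.0,                           # example value
        timesteps=18,
        mask_token_id=126336,
        codebook_size=8192,
        uni_prompting=uni_prompting,
    )
    # `tokens` holds the filled sequence; ids in the image range can be shifted back by
    # len(uni_prompting.text_tokenizer) before decoding with the image tokenizer.
    print(texts[0])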