jaeikkim committed on
Commit db39f43
1 Parent(s): 50db107

Cleanup binaries before space push

MMaDA/inference/common.py CHANGED
@@ -51,6 +51,7 @@ def build_uni_prompting(cfg) -> Tuple[UniversalPrompting, AutoTokenizer]:
             "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>",
             "<|i2i|>", "<|v2s|>", "<|s2s|>",
             "<|v2t|>", "<|s2t|>", "<|t2s|>", "<|soa|>", "<|eoa|>",
+            "<|ti2ti|>", "<|t2ti|>",
         ),
         ignore_id=-100,
         cond_dropout_prob=cfg.training.cond_dropout_prob,
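Not part of the commit: a quick sanity-check sketch for the two new task tokens. The import path and the `cfg` object are assumptions (a config of whatever shape `build_uni_prompting` already expects), and it assumes UniversalPrompting registers these strings with the tokenizer it returns, as it does for the neighbouring special tokens.

# Hypothetical check; import path and cfg are assumptions, not part of this commit.
from inference.common import build_uni_prompting

uni_prompting, tokenizer = build_uni_prompting(cfg)
for tok in ("<|ti2ti|>", "<|t2ti|>"):
    # Expected to resolve to a real (non-unk) id once the tokens above are registered.
    print(tok, tokenizer.convert_tokens_to_ids(tok))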
MMaDA/models/modeling_omada.py CHANGED
@@ -1802,6 +1802,201 @@ class OMadaModelLM(LLaDAModelLM):
 
         return x
 
+    @torch.no_grad()
+    def ti2ti_generate(
+        self,
+        input_ids: torch.LongTensor = None,
+        uncond_input_ids: torch.LongTensor = None,
+        attention_mask=None,
+        uncond_attention_mask=None,
+        temperature=1.0,
+        timesteps=18,
+        timesteps_text: int | None = None,
+        timesteps_image: int | None = None,
+        guidance_scale=0,
+        noise_schedule=cosine_schedule,
+        generator: torch.Generator = None,
+        config=None,
+        seq_len=1024,
+        mask_token_id=126336,
+        resolution=512,
+        codebook_size=8192,
+        uni_prompting=None,
+        **kwargs,
+    ):
+        """
+        TI2TI generation that fills masked text and image tokens; allows separate timesteps.
+        Returns (filled_tokens, decoded_texts).
+        """
+        if input_ids is None or attention_mask is None:
+            raise ValueError("input_ids and attention_mask are required for ti2ti_generate.")
+        if uni_prompting is None:
+            raise ValueError("uni_prompting is required for ti2ti_generate.")
+
+        device = input_ids.device
+        text_vocab_size = len(uni_prompting.text_tokenizer)
+        image_vocab_start = text_vocab_size
+        image_vocab_end = image_vocab_start + codebook_size
+        timesteps_text = timesteps if timesteps_text is None else timesteps_text
+        timesteps_image = timesteps if timesteps_image is None else timesteps_image
+
+        seq = input_ids.clone()
+        if attention_mask is None:
+            attn = torch.ones_like(seq, dtype=torch.long)
+        else:
+            attn = attention_mask
+        use_guidance = uncond_input_ids is not None and guidance_scale > 0
+        if use_guidance:
+            seq_uncond = uncond_input_ids.clone()
+            if uncond_attention_mask is None:
+                attn_uncond = torch.ones_like(seq_uncond, dtype=torch.long)
+            else:
+                attn_uncond = uncond_attention_mask
+        else:
+            seq_uncond = None
+            attn_uncond = None
+        total_len = seq.shape[1]
+
+        def _uniform_transfer_plan(mask_bool: torch.Tensor, steps_count: int) -> Optional[torch.Tensor]:
+            """Evenly divide masked token updates across steps."""
+            if steps_count is None or steps_count <= 0:
+                return None
+            mask_num = mask_bool.sum(dim=1, keepdim=True)
+            if mask_num.numel() == 0:
+                return None
+            base = mask_num // steps_count
+            remainder = mask_num % steps_count
+            plan = torch.zeros(mask_num.size(0), steps_count, device=mask_bool.device, dtype=torch.int64) + base
+            for idx in range(mask_num.size(0)):
+                rem_val = remainder[idx].item()
+                if rem_val > 0:
+                    plan[idx, :rem_val] += 1
+            return plan
+
+        prompt_block_len = uni_prompting.max_text_len
+        soi_id = int(uni_prompting.sptids_dict.get("<|soi|>", torch.tensor([-1]))[0].item())
+        eoi_id = int(uni_prompting.sptids_dict.get("<|eoi|>", torch.tensor([-1]))[0].item())
+        pad_id = int(getattr(uni_prompting, "pad_id", 0))
+
+        def _locate_blocks(sample_seq: torch.Tensor, sample_attn: Optional[torch.Tensor]):
+            # Find second (target) soi/eoi pair; fallback to template formula.
+            soi_positions = (sample_seq == soi_id).nonzero(as_tuple=True)[0]
+            eoi_positions = (sample_seq == eoi_id).nonzero(as_tuple=True)[0]
+            tgt_soi = None
+            tgt_eoi = None
+            if soi_positions.numel() >= 2:
+                tgt_soi = int(soi_positions[1].item())
+                tgt_eoi_candidates = [int(e.item()) for e in eoi_positions if int(e.item()) > tgt_soi]
+                if tgt_eoi_candidates:
+                    tgt_eoi = tgt_eoi_candidates[0]
+
+            if tgt_soi is None or tgt_eoi is None:
+                # fallback: compute with pad offset the old way
+                non_pad = (sample_seq != pad_id).nonzero(as_tuple=True)
+                pad_offset = int(non_pad[0][0].item()) if len(non_pad) > 0 and non_pad[0].numel() > 0 else 0
+                tgt_soi = pad_offset + 1 + 1 + seq_len + 1 + prompt_block_len + 1  # soi before target img
+                tgt_eoi = tgt_soi + seq_len + 1  # eoi after target img
+
+            img_start_local = tgt_soi + 1
+            img_end_local = min(tgt_eoi, sample_seq.size(0))
+
+            if sample_attn is not None:
+                text_attn = sample_attn[tgt_eoi + 1 :]
+                nonzero = (text_attn != 0).nonzero(as_tuple=True)
+                if len(nonzero) > 0 and nonzero[0].numel() > 0:
+                    last_idx = int(nonzero[0][-1].item())
+                    text_end_local = tgt_eoi + 1 + last_idx + 1
+                else:
+                    text_end_local = tgt_eoi + 1 + prompt_block_len
+            else:
+                text_end_local = tgt_eoi + 1 + prompt_block_len
+            text_start_local = tgt_eoi + 1
+            text_end_local = min(text_end_local, sample_seq.size(0))
+            return img_start_local, img_end_local, text_start_local, text_end_local
+
+        img_start, img_end, text_start, text_end = _locate_blocks(seq[0], attn[0] if attn is not None else None)
+        text_indices = torch.arange(total_len, device=device)
+        initial_text_mask = (seq == mask_token_id) & (text_indices >= text_start) & (text_indices < text_end)
+        text_transfer_plan = _uniform_transfer_plan(initial_text_mask, timesteps_text)
+        text_step_idx = 0
+
+        # Simultaneous fill: at each step, update image/text masks that still remain
+        max_steps = max(timesteps_image, timesteps_text)
+        for step in range(max_steps):
+            mask_map = seq == mask_token_id
+            img_mask = mask_map & (text_indices >= img_start) & (text_indices < img_end) if step < timesteps_image else None
+            text_mask = mask_map & (text_indices >= text_start) & (text_indices < text_end) if step < timesteps_text else None
+            if not ((img_mask is not None and img_mask.any()) or (text_mask is not None and text_mask.any())):
+                break
+
+            attn_bias = (attn[:, :, None] & attn[:, None, :]).bool().unsqueeze(1)
+            logits_cond = self(seq, attention_bias=attn_bias).logits
+            if use_guidance:
+                attn_bias_uncond = (attn_uncond[:, :, None] & attn_uncond[:, None, :]).bool().unsqueeze(1)
+                logits_uncond = self(seq_uncond, attention_bias=attn_bias_uncond).logits
+                logits = logits_uncond + (guidance_scale + 1.0) * (logits_cond - logits_uncond)
+            else:
+                logits = logits_cond
+
+            if text_mask is not None and text_mask.any():
+                logits_text = logits[..., :text_vocab_size]
+                probs_text = logits_text.softmax(dim=-1)
+                sampled_text = torch.multinomial(
+                    probs_text.view(-1, text_vocab_size),
+                    1,
+                    replacement=False
+                ).view(*logits_text.shape[:2])
+                sampled_probs = torch.gather(
+                    probs_text, dim=-1, index=sampled_text.unsqueeze(-1)
+                ).squeeze(-1)
+                candidate_seq = torch.where(text_mask, sampled_text, seq)
+                confidence = torch.full_like(sampled_probs, float("-inf"))
+                confidence = torch.where(text_mask, sampled_probs, confidence)
+                if text_transfer_plan is not None and text_step_idx < text_transfer_plan.shape[1]:
+                    transfer_counts = text_transfer_plan[:, text_step_idx]
+                else:
+                    transfer_counts = text_mask.sum(dim=1)
+                transfer_mask = torch.zeros_like(text_mask, dtype=torch.bool)
+                for b_idx in range(seq.shape[0]):
+                    mask_count = int(text_mask[b_idx].sum().item())
+                    if mask_count == 0:
+                        continue
+                    k = int(min(max(transfer_counts[b_idx].item(), 0), mask_count))
+                    if k <= 0:
+                        continue
+                    _, top_idx = torch.topk(confidence[b_idx], k=k)
+                    transfer_mask[b_idx, top_idx] = True
+                if transfer_mask.any():
+                    seq = torch.where(transfer_mask, candidate_seq, seq)
+                text_step_idx += 1
+
+            if img_mask is not None and img_mask.any():
+                logits_img = logits[..., image_vocab_start:image_vocab_end]
+                probs_img = logits_img.softmax(dim=-1)
+                sampled_img = torch.multinomial(
+                    probs_img.view(-1, codebook_size),
+                    1,
+                    replacement=False
+                ).view(*logits_img.shape[:2]) + image_vocab_start
+                seq = torch.where(img_mask, sampled_img, seq)
+
+            if use_guidance:
+                updated_mask = torch.zeros_like(seq, dtype=torch.bool)
+                if img_mask is not None:
+                    updated_mask |= img_mask
+                if text_mask is not None:
+                    updated_mask |= text_mask
+                seq_uncond = torch.where(updated_mask, seq, seq_uncond)
+
+        # Decode text tokens from filled sequence
+        pred_texts = []
+        for row in seq:
+            text_tokens = [int(t) for t in row.tolist() if 0 <= t < text_vocab_size]
+            pred_texts.append(uni_prompting.text_tokenizer.decode(text_tokens, skip_special_tokens=True))
+
+        return seq, pred_texts
+
+
     @torch.no_grad()
     def mmu_generate_fast(self, idx=None, input_embeddings=None, max_new_tokens=128, steps=128,block_length=128, temperature=0.0, top_k=None, eot_token=None, cfg_scale=0.0, remasking='low_confidence', mask_id=126336, attention_mask=None):
         """