| import enum | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| import torch | |
| class NeTIBatch: | |
| input_ids: torch.Tensor | |
| placeholder_token_id: int | |
| timesteps: torch.Tensor | |
| unet_layers: torch.Tensor | |
| truncation_idx: Optional[int] = None | |
| class PESigmas: | |
| sigma_t: float | |
| sigma_l: float | |