# autoencoder/register_autoencoder.py
# Author: amaye15
# Commit ced6e93: "Initial commit: AutoEncoder model"
"""
Registration script for Autoencoder models with Hugging Face AutoModel framework.
"""
from transformers import AutoConfig, AutoModel
from configuration_autoencoder import AutoencoderConfig
from modeling_autoencoder import AutoencoderModel, AutoencoderForReconstruction
def register_autoencoder_models():
    """
    Register the autoencoder config and base model with the Transformers Auto classes.

    This function registers:
    - AutoencoderConfig with AutoConfig, under the model type "autoencoder"
    - AutoencoderModel with AutoModel, keyed by AutoencoderConfig

    AutoencoderForReconstruction is NOT registered with any Auto class by this
    function; import it directly from ``modeling_autoencoder`` when needed.

    After calling this function, you can use:
    - AutoConfig.from_pretrained() to load autoencoder configs
    - AutoModel.from_pretrained() to load AutoencoderModel instances

    A confirmation message is printed to stdout on success.

    NOTE(review): the Transformers ``register`` calls appear to reject
    duplicate registrations, so calling this function twice in one process
    may raise — confirm before relying on repeated calls.
    """
    # Register configuration
    AutoConfig.register("autoencoder", AutoencoderConfig)
    # Register base model
    AutoModel.register(AutoencoderConfig, AutoencoderModel)
    # Note: For task-specific models like AutoencoderForReconstruction,
    # we would typically create a custom AutoModelForReconstruction class
    # and register it separately. For now, users can import directly.
    print("✅ Autoencoder models registered with Hugging Face AutoModel framework!")
    print("You can now use:")
    print(" - AutoConfig.from_pretrained() for configs")
    print(" - AutoModel.from_pretrained() for models")
    print(" - Direct imports for task-specific models")
def register_for_auto_class():
    """
    Tag the config and model classes for Transformers' auto-class support.

    The config class is tagged via ``register_for_auto_class()`` with no
    argument. AutoencoderModel and AutoencoderForReconstruction are both
    tagged with the "AutoModel" target. This is intended to let the classes
    be discovered again when using save_pretrained() and from_pretrained().

    A confirmation message is printed to stdout when done.
    """
    # Config first, relying on the method's default target class.
    AutoencoderConfig.register_for_auto_class()

    # Both model classes share the same "AutoModel" target, tagged in order.
    for model_cls in (AutoencoderModel, AutoencoderForReconstruction):
        model_cls.register_for_auto_class("AutoModel")

    print("✅ Models registered for auto class functionality!")
if __name__ == "__main__":
    # Script entry point: run both registration steps, in this order —
    # first the AutoConfig/AutoModel mappings, then the auto-class tags.
    for _registration_step in (register_autoencoder_models, register_for_auto_class):
        _registration_step()