Spaces:
Runtime error
Runtime error
Update modeling_metalatte.py
Browse files- modeling_metalatte.py +0 -11
modeling_metalatte.py
CHANGED
|
@@ -218,17 +218,6 @@ class MultitaskProteinModel(PreTrainedModel):
|
|
| 218 |
|
| 219 |
# Initialize weights and apply final processing
|
| 220 |
self.post_init()
|
| 221 |
-
|
| 222 |
-
@classmethod
|
| 223 |
-
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 224 |
-
config = kwargs.pop("config", None)
|
| 225 |
-
if config is None:
|
| 226 |
-
config = MetaLATTEConfig.from_pretrained(pretrained_model_name_or_path)
|
| 227 |
-
|
| 228 |
-
model = cls(config)
|
| 229 |
-
state_dict = torch.load(f"{pretrained_model_name_or_path}/pytorch_model.bin", map_location=torch.device('cpu'))['state_dict']
|
| 230 |
-
model.load_state_dict(state_dict, strict=False)
|
| 231 |
-
return model
|
| 232 |
|
| 233 |
def forward(self, input_ids, attention_mask=None):
|
| 234 |
outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
|
|
|
| 218 |
|
| 219 |
# Initialize weights and apply final processing
|
| 220 |
self.post_init()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
|
| 222 |
def forward(self, input_ids, attention_mask=None):
|
| 223 |
outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|