Update README.md
README.md
CHANGED
@@ -249,209 +249,27 @@ Based on AltCLIP, we have also developed the AltDiffusion model, visualized as f
## 模型推理 Inference

```python
import torch
from PIL import Image
from flagai.auto_model.auto_loader import AutoLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A single line of code automatically downloads the weights to './checkpoints/clip-xlmr-large' and loads the CLIP model weights
# Modelhub address: Modelhub(https://model.baai.ac.cn/models)
loader = AutoLoader(
    task_name="txt_img_matching",
    model_dir="./checkpoints",
    model_name="AltCLIP-XLMR-L"
)
# Get the loaded model
model = loader.get_model()
# Get the tokenizer
tokenizer = loader.get_tokenizer()
# Get the transform used to preprocess images
transform = loader.get_transform()

model.eval()
model.to(device)

# Inference: match the image against the candidate texts
image = Image.open("./dog.jpeg")
image = transform(image)
image = torch.tensor(image["pixel_values"]).to(device)
text = tokenizer(["a rat", "a dog", "a cat"])["input_ids"]

text = torch.tensor(text).to(device)

with torch.no_grad():
    image_features = model.get_image_features(image)
    text_features = model.get_text_features(text)
    text_probs = (image_features @ text_features.T).softmax(dim=-1)

print(text_probs.cpu().numpy()[0].tolist())
```

## CLIP微调 Finetuning

Fine-tuning uses the cifar10 dataset, with FlagAI's Trainer to quickly start the training process.

```python
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
from flagai.auto_model.auto_loader import AutoLoader
import os
from flagai.trainer import Trainer
from torchvision.datasets import (
    CIFAR10
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_root = "./clip_benchmark_datasets"
dataset_name = "cifar10"

batch_size = 4
classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

auto_loader = AutoLoader(
    task_name="txt_img_matching",
    model_dir="./checkpoints/",
    model_name="AltCLIP-XLMR-L"  # Load the checkpoints from Modelhub(model.baai.ac.cn/models)
)

model = auto_loader.get_model()
model.to(device)
model.eval()
tokenizer = auto_loader.get_tokenizer()
transform = auto_loader.get_transform()

trainer = Trainer(env_type="pytorch",
                  pytorch_device=device,
                  experiment_name="clip_finetuning",
                  batch_size=4,
                  lr=1e-4,
                  epochs=10,
                  log_interval=10)

dataset = CIFAR10(root=os.path.join(dataset_root, dataset_name),
                  transform=transform,
                  download=True)

def cifar10_collate_fn(batch):
    # image shape is (batch, 3, 224, 224)
    images = torch.tensor([b[0]["pixel_values"][0] for b in batch])
    # text_id shape is (batch, n); map the CIFAR10 label index to its class name
    input_ids = torch.tensor([tokenizer(f"a photo of a {classes[b[1]]}", padding=True, truncation=True, max_length=77)["input_ids"] for b in batch])

    return {
        "pixel_values": images,
        "input_ids": input_ids
    }

if __name__ == "__main__":
    trainer.train(model=model, train_dataset=dataset, collate_fn=cifar10_collate_fn)
```

We provide validation scripts that can be run directly on the cifar10 dataset.

```python
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")
import torch
from flagai.auto_model.auto_loader import AutoLoader
from metrics import zeroshot_classification
import json
import os
from torchvision.datasets import CIFAR10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
maxlen = 256

dataset_root = "./clip_benchmark_datasets"
dataset_name = "cifar10"

auto_loader = AutoLoader(
    task_name="txt_img_matching",
    model_dir="./checkpoints/",
    model_name="AltCLIP-XLMR-L"
)

model = auto_loader.get_model()
model.to(device)
model.eval()
tokenizer = auto_loader.get_tokenizer()
transform = auto_loader.get_transform()

dataset = CIFAR10(root=os.path.join(dataset_root, dataset_name),
                  transform=transform,
                  download=True)
batch_size = 128
num_workers = 4

template = {"cifar10": [
    "a photo of a {c}.",
    "a blurry photo of a {c}.",
    "a black and white photo of a {c}.",
    "a low contrast photo of a {c}.",
    "a high contrast photo of a {c}.",
    "a bad photo of a {c}.",
    "a good photo of a {c}.",
    "a photo of a small {c}.",
    "a photo of a big {c}.",
    "a photo of the {c}.",
    "a blurry photo of the {c}.",
    "a black and white photo of the {c}.",
    "a low contrast photo of the {c}.",
    "a high contrast photo of the {c}.",
    "a bad photo of the {c}.",
    "a good photo of the {c}.",
    "a photo of the small {c}.",
    "a photo of the big {c}."
],
}

def evaluate():
    if dataset:
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )
        classnames = dataset.classes if hasattr(dataset, "classes") else None

        zeroshot_templates = template["cifar10"]
        metrics = zeroshot_classification.evaluate(
            model,
            dataloader,
            tokenizer,
            classnames,
            zeroshot_templates,
            device=device,
            amp=True,
        )

        dump = {
            "dataset": dataset_name,
            "metrics": metrics
        }

        print(dump)
        with open("./result.txt", "w") as f:
            json.dump(dump, f)
        return metrics

if __name__ == "__main__":
    evaluate()
```

## 模型推理 Inference

Please download the code from [FlagAI AltCLIP](https://github.com/FlagAI-Open/FlagAI/tree/master/examples/AltCLIP)

```python
from PIL import Image
import requests

# transformers version >= 4.21.0
from modeling_altclip import AltCLIP
from processing_altclip import AltCLIPProcessor

# our repo is currently private, so we need `use_auth_token=True`
model = AltCLIP.from_pretrained("BAAI/AltCLIP")
processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")
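# A hedged sketch (not in the original snippet): while the repo is private, the auth flag
# mentioned above can be passed straight to `from_pretrained`, e.g.:
#   model = AltCLIP.from_pretrained("BAAI/AltCLIP", use_auth_token=True)
#   processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP", use_auth_token=True)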

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)

outputs = model(**inputs)
logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
```
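
For a quick sanity check, the probabilities can be matched back to the candidate captions. This is a minimal sketch that only assumes the `probs` tensor and the two candidate captions produced by the snippet above:

```python
# Pair each candidate caption with its predicted probability (illustrative sketch)
captions = ["a photo of a cat", "a photo of a dog"]
for caption, p in zip(captions, probs[0].tolist()):
    print(f"{caption}: {p:.4f}")
```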