Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 9 additions & 14 deletions diffusion/evaluation/clean_fid_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@

import clip
import torch
import wandb
from cleanfid import fid
from composer import ComposerModel, Trainer
from composer.core import get_precision_context
from composer.loggers import LoggerDestination, WandBLogger
from composer.loggers import LoggerDestination
from composer.utils import dist
from torch.utils.data import DataLoader
from torchmetrics.multimodal import CLIPScore
Expand Down Expand Up @@ -91,19 +90,15 @@ def __init__(self,
self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
self.sdxl = model.sdxl

# Init loggers
if self.loggers and dist.get_local_rank() == 0:
for logger in self.loggers:
if isinstance(logger, WandBLogger):
wandb.init(**logger._init_kwargs)

# Load the model
Trainer(model=self.model,
load_path=self.load_path,
load_weights_only=True,
load_strict_model_weights=load_strict_model_weights,
eval_dataloader=self.eval_dataloader,
seed=self.seed)
trainer = Trainer(model=self.model,
load_path=self.load_path,
load_weights_only=True,
load_strict_model_weights=load_strict_model_weights,
eval_dataloader=self.eval_dataloader,
seed=self.seed,
loggers=self.loggers)
self.trainer = trainer

# Move CLIP metric to device
self.device = dist.get_local_rank()
Expand Down