Commit c1f953f

Small bug fixes for running models without tokenizers (#168)
1 parent 2a6fff4 commit c1f953f

File tree: 2 files changed (+8, -5 lines)

diffusion/callbacks/log_diffusion_images.py

Lines changed: 2 additions & 2 deletions
@@ -98,7 +98,7 @@ def __init__(self,
         latent_batch = {}
         tokenized_t5 = t5_tokenizer(batch,
                                     padding='max_length',
-                                    max_length=t5_tokenizer.model.max_length,
+                                    max_length=t5_tokenizer.model_max_length,
                                     truncation=True,
                                     return_tensors='pt')
         t5_attention_mask = tokenized_t5['attention_mask'].to(torch.bool).cuda()
@@ -108,7 +108,7 @@ def __init__(self,

         tokenized_clip = clip_tokenizer(batch,
                                         padding='max_length',
-                                        max_length=t5_tokenizer.model.max_length,
+                                        max_length=clip_tokenizer.model_max_length,
                                         truncation=True,
                                         return_tensors='pt')
         clip_attention_mask = tokenized_clip['attention_mask'].cuda()
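
For context on these two hunks: Hugging Face tokenizers expose their maximum sequence length as a model_max_length attribute directly on the tokenizer object, so the old spelling t5_tokenizer.model.max_length raised an AttributeError; the CLIP call was additionally truncating against the T5 tokenizer's limit instead of its own. A minimal sketch of the corrected usage follows; the checkpoint names and prompt are illustrative, not taken from this repo's configs:

from transformers import AutoTokenizer

# Illustrative checkpoints -- the repo's configs pick the actual models.
t5_tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-base')
clip_tokenizer = AutoTokenizer.from_pretrained('openai/clip-vit-large-patch14')

batch = ['a photo of an astronaut riding a horse']

# model_max_length lives on the tokenizer itself (e.g. 512 for T5, 77 for
# CLIP), so each tokenizer must pad/truncate to its own limit.
tokenized_t5 = t5_tokenizer(batch,
                            padding='max_length',
                            max_length=t5_tokenizer.model_max_length,
                            truncation=True,
                            return_tensors='pt')
tokenized_clip = clip_tokenizer(batch,
                                padding='max_length',
                                max_length=clip_tokenizer.model_max_length,
                                truncation=True,
                                return_tensors='pt')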

diffusion/train.py

Lines changed: 6 additions & 3 deletions
@@ -88,17 +88,20 @@ def train(config: DictConfig) -> None:

     model: ComposerModel = hydra.utils.instantiate(config.model)

+    # If the model has a tokenizer, we'll need it for the dataset
+    if hasattr(model, 'tokenizer'):
+        tokenizer = model.tokenizer
+    else:
+        tokenizer = None
+
     if hasattr(model, 'autoencoder_loss'):
         # Check if this is training an autoencoder. If so, the optimizer needs different param groups
         optimizer = make_autoencoder_optimizer(config, model)
-        tokenizer = None
     elif isinstance(model, ComposerTextToImageMMDiT):
         # Check if this is training a transformer. If so, the optimizer needs different param groups
         optimizer = make_transformer_optimizer(config, model)
-        tokenizer = model.tokenizer
     else:
         optimizer = hydra.utils.instantiate(config.optimizer, params=model.parameters())
-        tokenizer = model.tokenizer

     # Load train dataset. Currently this expects to load according to the datasetHparam method.
     # This means adding external datasets is currently not super easy. Will refactor or check for
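
This hunk hoists tokenizer resolution out of the optimizer branches, so a model without a tokenizer attribute no longer crashes with an AttributeError regardless of which optimizer path it takes. The same fallback can also be expressed with getattr's default argument; a small self-contained sketch with hypothetical stand-in model classes, not code from this repo:

class Autoencoder:
    """Hypothetical stand-in for a model that has no tokenizer."""

class TextToImageModel:
    """Hypothetical stand-in for a model that carries its tokenizer."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

def resolve_tokenizer(model):
    # Behaves like the new hasattr check: fall back to None when the
    # model does not carry a tokenizer.
    return getattr(model, 'tokenizer', None)

assert resolve_tokenizer(Autoencoder()) is None
assert resolve_tokenizer(TextToImageModel('clip-tokenizer')) == 'clip-tokenizer'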
