File tree Expand file tree Collapse file tree 2 files changed +8
-5
lines changed Expand file tree Collapse file tree 2 files changed +8
-5
lines changed Original file line number Diff line number Diff line change @@ -98,7 +98,7 @@ def __init__(self,
9898 latent_batch = {}
9999 tokenized_t5 = t5_tokenizer (batch ,
100100 padding = 'max_length' ,
101- max_length = t5_tokenizer .model . max_length ,
101+ max_length = t5_tokenizer .model_max_length ,
102102 truncation = True ,
103103 return_tensors = 'pt' )
104104 t5_attention_mask = tokenized_t5 ['attention_mask' ].to (torch .bool ).cuda ()
@@ -108,7 +108,7 @@ def __init__(self,
108108
109109 tokenized_clip = clip_tokenizer (batch ,
110110 padding = 'max_length' ,
111- max_length = t5_tokenizer . model . max_length ,
111+ max_length = clip_tokenizer . model_max_length ,
112112 truncation = True ,
113113 return_tensors = 'pt' )
114114 clip_attention_mask = tokenized_clip ['attention_mask' ].cuda ()
Original file line number Diff line number Diff line change @@ -88,17 +88,20 @@ def train(config: DictConfig) -> None:
8888
8989 model : ComposerModel = hydra .utils .instantiate (config .model )
9090
91+ # If the model has a tokenizer, we'll need it for the dataset
92+ if hasattr (model , 'tokenizer' ):
93+ tokenizer = model .tokenizer
94+ else :
95+ tokenizer = None
96+
9197 if hasattr (model , 'autoencoder_loss' ):
9298 # Check if this is training an autoencoder. If so, the optimizer needs different param groups
9399 optimizer = make_autoencoder_optimizer (config , model )
94- tokenizer = None
95100 elif isinstance (model , ComposerTextToImageMMDiT ):
96101 # Check if this is training a transformer. If so, the optimizer needs different param groups
97102 optimizer = make_transformer_optimizer (config , model )
98- tokenizer = model .tokenizer
99103 else :
100104 optimizer = hydra .utils .instantiate (config .optimizer , params = model .parameters ())
101- tokenizer = model .tokenizer
102105
103106 # Load train dataset. Currently this expects to load according to the datasetHparam method.
104107 # This means adding external datasets is currently not super easy. Will refactor or check for
You can’t perform that action at this time.
0 commit comments