@@ -1190,4 +1190,171 @@ struct FluxCLIPEmbedder : public Conditioner {
11901190 }
11911191};
11921192
1193+ struct SimpleT5Embedder : public Conditioner {
1194+ T5UniGramTokenizer t5_tokenizer;
1195+ std::shared_ptr<T5Runner> t5;
1196+
1197+ SimpleT5Embedder (ggml_backend_t backend,
1198+ std::map<std::string, enum ggml_type>& tensor_types,
1199+ int clip_skip = -1 ) {
1200+ t5 = std::make_shared<T5Runner>(backend, tensor_types, " text_encoders.t5xxl.transformer" );
1201+ }
1202+
1203+ void get_param_tensors (std::map<std::string, struct ggml_tensor *>& tensors) {
1204+ t5->get_param_tensors (tensors, " text_encoders.t5xxl.transformer" );
1205+ }
1206+
1207+ void alloc_params_buffer () {
1208+ t5->alloc_params_buffer ();
1209+ }
1210+
1211+ void free_params_buffer () {
1212+ t5->free_params_buffer ();
1213+ }
1214+
1215+ size_t get_params_buffer_size () {
1216+ size_t buffer_size = t5->get_params_buffer_size ();
1217+ return buffer_size;
1218+ }
1219+
1220+ std::pair<std::vector<int >, std::vector<float >> tokenize (std::string text,
1221+ size_t max_length = 0 ,
1222+ bool padding = false ) {
1223+ auto parsed_attention = parse_prompt_attention (text);
1224+
1225+ {
1226+ std::stringstream ss;
1227+ ss << " [" ;
1228+ for (const auto & item : parsed_attention) {
1229+ ss << " ['" << item.first << " ', " << item.second << " ], " ;
1230+ }
1231+ ss << " ]" ;
1232+ LOG_DEBUG (" parse '%s' to %s" , text.c_str (), ss.str ().c_str ());
1233+ }
1234+
1235+ auto on_new_token_cb = [&](std::string& str, std::vector<int32_t >& bpe_tokens) -> bool {
1236+ return false ;
1237+ };
1238+
1239+ std::vector<int > t5_tokens;
1240+ std::vector<float > t5_weights;
1241+ for (const auto & item : parsed_attention) {
1242+ const std::string& curr_text = item.first ;
1243+ float curr_weight = item.second ;
1244+
1245+ std::vector<int > curr_tokens = t5_tokenizer.Encode (curr_text, true );
1246+ t5_tokens.insert (t5_tokens.end (), curr_tokens.begin (), curr_tokens.end ());
1247+ t5_weights.insert (t5_weights.end (), curr_tokens.size (), curr_weight);
1248+ }
1249+
1250+ t5_tokenizer.pad_tokens (t5_tokens, t5_weights, max_length, padding);
1251+
1252+ // for (int i = 0; i < clip_l_tokens.size(); i++) {
1253+ // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", ";
1254+ // }
1255+ // std::cout << std::endl;
1256+
1257+ // for (int i = 0; i < t5_tokens.size(); i++) {
1258+ // std::cout << t5_tokens[i] << ":" << t5_weights[i] << ", ";
1259+ // }
1260+ // std::cout << std::endl;
1261+
1262+ return {t5_tokens, t5_weights};
1263+ }
1264+
1265+ SDCondition get_learned_condition_common (ggml_context* work_ctx,
1266+ int n_threads,
1267+ std::pair<std::vector<int >, std::vector<float >> token_and_weights,
1268+ int clip_skip,
1269+ bool force_zero_embeddings = false ) {
1270+ auto & t5_tokens = token_and_weights.first ;
1271+ auto & t5_weights = token_and_weights.second ;
1272+
1273+ int64_t t0 = ggml_time_ms ();
1274+ struct ggml_tensor * hidden_states = NULL ; // [N, n_token, 4096]
1275+ struct ggml_tensor * chunk_hidden_states = NULL ; // [n_token, 4096]
1276+ std::vector<float > hidden_states_vec;
1277+
1278+ size_t chunk_len = 256 ;
1279+ size_t chunk_count = t5_tokens.size () / chunk_len;
1280+ for (int chunk_idx = 0 ; chunk_idx < chunk_count; chunk_idx++) {
1281+ std::vector<int > chunk_tokens (t5_tokens.begin () + chunk_idx * chunk_len,
1282+ t5_tokens.begin () + (chunk_idx + 1 ) * chunk_len);
1283+ std::vector<float > chunk_weights (t5_weights.begin () + chunk_idx * chunk_len,
1284+ t5_weights.begin () + (chunk_idx + 1 ) * chunk_len);
1285+
1286+ auto input_ids = vector_to_ggml_tensor_i32 (work_ctx, chunk_tokens);
1287+
1288+ t5->compute (n_threads,
1289+ input_ids,
1290+ &chunk_hidden_states,
1291+ work_ctx);
1292+ {
1293+ auto tensor = chunk_hidden_states;
1294+ float original_mean = ggml_tensor_mean (tensor);
1295+ for (int i2 = 0 ; i2 < tensor->ne [2 ]; i2++) {
1296+ for (int i1 = 0 ; i1 < tensor->ne [1 ]; i1++) {
1297+ for (int i0 = 0 ; i0 < tensor->ne [0 ]; i0++) {
1298+ float value = ggml_tensor_get_f32 (tensor, i0, i1, i2);
1299+ value *= chunk_weights[i1];
1300+ ggml_tensor_set_f32 (tensor, value, i0, i1, i2);
1301+ }
1302+ }
1303+ }
1304+ float new_mean = ggml_tensor_mean (tensor);
1305+ ggml_tensor_scale (tensor, (original_mean / new_mean));
1306+ }
1307+
1308+ int64_t t1 = ggml_time_ms ();
1309+ LOG_DEBUG (" computing condition graph completed, taking %" PRId64 " ms" , t1 - t0);
1310+ if (force_zero_embeddings) {
1311+ float * vec = (float *)chunk_hidden_states->data ;
1312+ for (int i = 0 ; i < ggml_nelements (chunk_hidden_states); i++) {
1313+ vec[i] = 0 ;
1314+ }
1315+ }
1316+
1317+ hidden_states_vec.insert (hidden_states_vec.end (),
1318+ (float *)chunk_hidden_states->data ,
1319+ ((float *)chunk_hidden_states->data ) + ggml_nelements (chunk_hidden_states));
1320+ }
1321+
1322+ hidden_states = vector_to_ggml_tensor (work_ctx, hidden_states_vec);
1323+ hidden_states = ggml_reshape_2d (work_ctx,
1324+ hidden_states,
1325+ chunk_hidden_states->ne [0 ],
1326+ ggml_nelements (hidden_states) / chunk_hidden_states->ne [0 ]);
1327+ return SDCondition (hidden_states, NULL , NULL );
1328+ }
1329+
1330+ SDCondition get_learned_condition (ggml_context* work_ctx,
1331+ int n_threads,
1332+ const std::string& text,
1333+ int clip_skip,
1334+ int width,
1335+ int height,
1336+ int adm_in_channels = -1 ,
1337+ bool force_zero_embeddings = false ) {
1338+ auto tokens_and_weights = tokenize (text, 256 , true );
1339+ return get_learned_condition_common (work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings);
1340+ }
1341+
1342+ std::tuple<SDCondition, std::vector<bool >> get_learned_condition_with_trigger (ggml_context* work_ctx,
1343+ int n_threads,
1344+ const std::string& text,
1345+ int clip_skip,
1346+ int width,
1347+ int height,
1348+ int num_input_imgs,
1349+ int adm_in_channels = -1 ,
1350+ bool force_zero_embeddings = false ) {
1351+ GGML_ASSERT (0 && " Not implemented yet!" );
1352+ }
1353+
1354+ std::string remove_trigger_from_prompt (ggml_context* work_ctx,
1355+ const std::string& prompt) {
1356+ GGML_ASSERT (0 && " Not implemented yet!" );
1357+ }
1358+ };
1359+
11931360#endif
0 commit comments