Skip to content

Commit 364d801

Browse files
committed
modify deepseek load_weights() to align with main
Signed-off-by: qgai <qgai@nvidia.com>
1 parent e6955d1 commit 364d801

File tree

1 file changed: +4 −3 lines changed

1 file changed

+4
-3
lines changed

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -157,8 +157,7 @@ def split(v, tp_size, idx, dim=0):
157157
return v
158158
if len(v.shape) == 1:
159159
return torch.chunk(v, tp_size)[idx].contiguous()
160-
else:
161-
return torch.chunk(v, tp_size, dim=dim)[idx].contiguous()
160+
return torch.chunk(v, tp_size, dim=dim)[idx].contiguous()
162161

163162
def split_matrix_tp(v, tensor_parallel, rank, dim):
164163
return split(v, tensor_parallel, rank, dim=dim)
@@ -273,7 +272,9 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
273272

274273
for name, module in tqdm(all_named_modules.items(),
275274
desc="Loading weights"):
276-
if len(module._parameters) > 0:
275+
if len(module._parameters) <= 0 or name.startswith("draft_model"):
276+
continue
277+
else:
277278
names = name.split('.')
278279
parent_module_name = '.'.join(names[:-1])
279280
if "model.layers" in name and int(

0 commit comments

Comments
 (0)