Skip to content

Commit 1898a10

Browse files
committed
modify deepseek load_weights() to align with main
Signed-off-by: qgai <qgai@nvidia.com>
1 parent 83cdb13 commit 1898a10

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -153,8 +153,7 @@ def split(v, tp_size, idx, dim=0):
153153
return v
154154
if len(v.shape) == 1:
155155
return torch.chunk(v, tp_size)[idx].contiguous()
156-
else:
157-
return torch.chunk(v, tp_size, dim=dim)[idx].contiguous()
156+
return torch.chunk(v, tp_size, dim=dim)[idx].contiguous()
158157

159158
def split_matrix_tp(v, tensor_parallel, rank, dim):
160159
return split(v, tensor_parallel, rank, dim=dim)
@@ -269,7 +268,9 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
269268

270269
for name, module in tqdm(all_named_modules.items(),
271270
desc="Loading weights"):
272-
if len(module._parameters) > 0:
271+
if len(module._parameters) <= 0 or name.startswith("draft_model"):
272+
continue
273+
else:
273274
names = name.split('.')
274275
parent_module_name = '.'.join(names[:-1])
275276
if "model.layers" in name and int(

0 commit comments

Comments
 (0)