Commit 53bc830

chang-l authored and faradawn committed
[https://nvbugs/5549081][fix] Fix device id assignment for some vision models (NVIDIA#8070)
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
Signed-off-by: Chang Liu <9713593+chang-l@users.noreply.github.com>
Signed-off-by: Faradawn Yang <faradawny@gmail.com>
1 parent e92d745 commit 53bc830

3 files changed: 6 additions, 3 deletions

tensorrt_llm/_torch/models/modeling_hyperclovax.py

Lines changed: 2 additions & 1 deletion
@@ -726,7 +726,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         self.vision_config = self.pretrained_config.vision_config
 
         model_path = self.pretrained_config._name_or_path
-        self.device = f"cuda:{model_config.mapping.rank}"
+        # TODO: use config.mapping.get_local_rank() instead
+        self.device = f"cuda:{torch.cuda.current_device()}"
 
         hf_model_config = AutoConfig.from_pretrained(model_path,
                                                      trust_remote_code=True)

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 2 additions & 1 deletion
@@ -998,7 +998,8 @@ def __init__(self, model_config: ModelConfig[Llama4Config], *args,
                  **kwargs):
         super().__init__()
         self.pretrained_config = model_config.pretrained_config
-        self.device = f"cuda:{model_config.mapping.rank}"
+        # TODO: use config.mapping.get_local_rank() instead
+        self.device = f"cuda:{torch.cuda.current_device()}"
 
         self.dtype = self.pretrained_config.text_config.torch_dtype
 
tensorrt_llm/_torch/models/modeling_llava_next.py

Lines changed: 2 additions & 1 deletion
@@ -295,7 +295,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
         super().__init__()
         self.model_config = model_config
         self.pretrained_config = model_config.pretrained_config
-        self.device = f"cuda:{model_config.mapping.rank}"
+        # TODO: use config.mapping.get_local_rank() instead
+        self.device = f"cuda:{torch.cuda.current_device()}"
         model_path = self.pretrained_config._name_or_path
 
         # Determine the actual local path for model files
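
Note on the change: all three hunks stop building the device string from `model_config.mapping.rank` and instead use `torch.cuda.current_device()`, which returns the index of the device currently selected in the calling process. One likely motivation (an assumption, the commit message does not spell it out) is that a global rank is not always a valid per-node CUDA index, for example on multi-node runs where ranks exceed the number of GPUs on one node. The sketch below illustrates that mismatch in plain Python. The helper names and the `gpus_per_node` parameter are hypothetical, and the modulo mapping assumes ranks are assigned contiguously per node with the same GPU count on every node. It is not the implementation of `get_local_rank()` referenced in the TODO.

```python
def local_device_index(global_rank: int, gpus_per_node: int) -> int:
    """Map a global rank to a per-node device index.

    Hypothetical helper for illustration only. Assumes ranks are laid out
    contiguously per node and every node has the same number of GPUs.
    """
    if global_rank < 0:
        raise ValueError("global_rank must be non-negative")
    if gpus_per_node <= 0:
        raise ValueError("gpus_per_node must be positive")
    return global_rank % gpus_per_node


def device_string(global_rank: int, gpus_per_node: int) -> str:
    """Build a 'cuda:N' string from the per-node index, not the global rank."""
    return f"cuda:{local_device_index(global_rank, gpus_per_node)}"


if __name__ == "__main__":
    # Two nodes with 8 GPUs each: global rank 9 lives on the second node.
    # Formatting the global rank directly would give "cuda:9", which does
    # not exist on an 8-GPU node; the per-node index is 1.
    print(device_string(9, 8))
```

On a single node with one process per GPU the global rank and the local index coincide, which may be why the original code worked in common setups.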
