Skip to content

Commit c8d9cdc

Browse files
docs(mm): add readme for updating or adding new model support
1 parent e9c2411 commit c8d9cdc

File tree

1 file changed

+212
-0
lines changed

1 file changed

+212
-0
lines changed
Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# Model Management System
2+
3+
This document describes Invoke's model management system and common tasks for extending model support.
4+
5+
## Overview
6+
7+
The model management system handles the full lifecycle of models: identification, loading, and running. The system is extensible and supports multiple model architectures, formats, and quantization schemes.
8+
9+
### Three Major Subsystems
10+
11+
1. **Model Identification** (`configs/`): Determines model type, architecture, format, and metadata when users install models.
12+
2. **Model Loading** (`load/`): Loads models from disk into memory for inference.
13+
3. **Model Running**: Executes inference on loaded models. Implementation is scattered across the codebase, typically in architecture-specific inference code adjacent to `model_manager/`. The inference code is run in nodes in the graph execution system.
14+
15+
## Core Concepts
16+
17+
### Model Taxonomy
18+
19+
The `taxonomy.py` module defines the type system for models:
20+
21+
- `ModelType`: The kind of model (e.g., `Main`, `LoRA`, `ControlNet`, `VAE`).
22+
- `ModelFormat`: Storage format - may imply a quantization or some other quality (e.g., `Diffusers`, `Checkpoint`, `LyCORIS`, `BnbQuantizednf4b`).
23+
- `BaseModelType`: Associated pipeline architecture (e.g., `StableDiffusion1`, `StableDiffusionXL`, `Flux`). Models without an associated base use `Any` (e.g., `CLIPVision` is its own thing).
24+
- `ModelVariantType`, `FluxVariantType`, `ClipVariantType`: Architecture-specific variants.
25+
26+
These enums form a discriminated union that uniquely identifies each model configuration class.
27+
28+
### Model "Configs"
29+
30+
Model configs are Pydantic models that describe a model on disk. They include the model taxonomy, path, and any metadata needed for loading or running the model.
31+
32+
Model configs are stored in the database.
33+
34+
### Model Identification
35+
36+
When a user installs a model, the system attempts to identify it by trying each registered config class until one matches.
37+
38+
**Config Classes** (`configs/`):
39+
40+
- All config classes inherit from `Config_Base`, either directly or indirectly via some intermediary class (e.g., `Diffusers_Config_Base`, `Checkpoint_Config_Base`, or something narrower).
41+
- Each config class represents a specific, unique combination of `type`, `format`, `base`, and optional `variant`.
42+
- Config classes must implement `from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict) -> Self`. This method inspects the model on disk and raises `NotAMatchError` if the model doesn't match the config class, or returns an instance of the config class if it does.
43+
- `ModelOnDisk` is a helper class that abstracts the model weights. It should be the entrypoint for inspecting the model (e.g., loading state dicts).
44+
- Override fields allow users to provide hints (e.g., when differentiating between SD1/SD2/SDXL VAEs with identical structures).
45+
46+
**Identification Process**:
47+
48+
1. `ModelConfigFactory.from_model_on_disk()` is called with a path to the model.
49+
2. The factory iterates through all registered config classes, calling `from_model_on_disk()` on each.
50+
3. Each config class inspects the model (state dict keys, tensor shapes, config files, etc.).
51+
4. If a match is found, the config instance is returned. If multiple matches are found, they are prioritized (e.g., main models over LoRAs).
52+
5. If no match is found, an `Unknown_Config` is returned as a fallback.
53+
54+
**Utilities** (`identification_utils.py`):
55+
56+
- `NotAMatchError`: Exception raised when a model doesn't match a config class.
57+
- `get_config_dict_or_raise()`: Load JSON config files from diffusers/transformers models.
58+
- `raise_for_class_name()`: Validate class names in config files.
59+
- `raise_for_override_fields()`: Validate user-provided override fields against the config schema.
60+
- `state_dict_has_any_keys_*()`: Helpers for inspecting state dict keys.
61+
62+
### Model Loading
63+
64+
Model loaders handle instantiating models from disk into memory.
65+
66+
**Loader Classes** (`load/model_loaders/`):
67+
68+
- Loaders register themselves with a decorator `@ModelLoaderRegistry.register(base=..., type=..., format=...)`. The `type`, `format` and `base` indicate which configs classes the loader can handle.
69+
- Each loader implements `_load_model(self, config: AnyModelConfig, submodel_type: Optional[SubModelType]) -> AnyModel`.
70+
- Loaders are responsible for:
71+
- Loading model weights from the config's path.
72+
- Instantiating the correct model class (often using diffusers, transformers, or custom implementations).
73+
- Returning the in-memory model representation.
74+
75+
**Model Cache** (`load/model_cache/`):
76+
77+
> This system typically does not require changes to support new model types, but it is important to understand how it works.
78+
79+
- Manages models in memory with RAM and VRAM limits.
80+
- Handles moving models between CPU (storage device) and GPU (execution device).
81+
- Implements LRU eviction for RAM and smallest-first offload for VRAM.
82+
- Supports partial loading for large models on CUDA.
83+
- Thread-safe with locks on all public methods.
84+
85+
**Loading Process**:
86+
87+
1. The appropriate loader is selected based on the model config's `base`, `type`, and `format` attributes.
88+
2. The loader's `_load_model()` method is called with the model config.
89+
3. The loaded model is added to the model cache via `ModelCache.put()`.
90+
4. When needed, the model is moved into VRAM via `ModelCache.get()` and `ModelCache.lock()`.
91+
92+
### Model Running
93+
94+
Model running is architecture-specific and typically implemented in folders adjacent to `model_manager/`.
95+
96+
Inference code doesn't necessarily follow any specific pattern, and doesn't interact directly with the model management system except to receive model configs and loaded models.
97+
98+
At a high level, when a node needs to run a model, it will:
99+
100+
- Receive a model identifier as an input or constant. This is typically the model's database ID (aka the `key`).
101+
- The node will use the `InvocationContext` API to load the model. The request is dispatched to the model manager which will load the model and return the a model loader with a context manager that yields the in-memory model, mediating VRAM/RAM management as needed.
102+
- The node will run inference using the loaded model using whatever patterns or libraries it needs.
103+
104+
## Common Tasks
105+
106+
### Task 1: Improving Identification for a Supported Model Type
107+
108+
When identification fails or produces incorrect results for a model that should be supported, you may need to refine the identification logic.
109+
110+
**Steps**:
111+
112+
1. Obtain the failing model file or directory.
113+
2. Create a test case for it, following the instructions in `tests/model_identification/README.md`.
114+
3. Review the relevant config class in `configs/` (e.g., `configs/lora.py` for LoRA models).
115+
4. Examine the `from_model_on_disk()` method for some existing models to understand the patterns for identification logic.
116+
5. Inspect the failing model's files and structure:
117+
- For checkpoint files: Load the state dict and examine keys and tensor shapes.
118+
- For diffusers models: Examine the config files and directory structure.
119+
6. Update the identification logic to handle the new model variant. Common approaches:
120+
- Check for specific state dict keys or key patterns.
121+
- Inspect tensor shapes (e.g., `state_dict[key].shape`).
122+
- Parse config files for class names or configuration values.
123+
- Use helper functions from `identification_utils.py`.
124+
7. Run the test suite to verify the new logic works and doesn't break existing tests: `pytest tests/model_identification/test_identification.py`.
125+
- Make sure you have installed the test dependencies (e.g. `uv pip install -e ".[dev,test]"`).
126+
- If the model type is complex or has multiple variants, consider adding more test cases to cover edge cases.
127+
8. If, after successfully adding identification support for the model, it still doesn't work, you may need to update loading and/or inference code as well.
128+
129+
**Key Files**:
130+
131+
- Config class: `configs/<model_type>.py`
132+
- Identification utilities: `configs/identification_utils.py`
133+
- Taxonomy: `taxonomy.py`
134+
- Test README: `tests/model_identification/README.md`
135+
136+
### Task 2: Adding Support for a New Model Type
137+
138+
Adding a new model type requires implementing identification and loading logic. Inference and new nodes ("invocations") may be required if the model type doesn't fit into existing architectures or nodes.
139+
140+
**Steps**:
141+
142+
#### 1. Define Taxonomy
143+
144+
- Add a new `ModelType` enum value in `taxonomy.py` if needed.
145+
- Determine the appropriate `BaseModelType` (or use `Any` if not architecture-specific).
146+
- Add a new `ModelFormat` if the model uses a unique storage format.
147+
148+
You may need to add other attributes, depending on the model.
149+
150+
#### 2. Implement Config Class
151+
152+
- Create a new config file in `configs/` (e.g., `configs/new_model.py`).
153+
- Define a config class inheriting from `Config_Base` and appropriate format base class:
154+
- `Diffusers_Config_Base` for diffusers-style models.
155+
- `Checkpoint_Config_Base` for single-file checkpoint models.
156+
- Define `type`, `format`, and `base` as `Literal` fields with defaults. Remember, these must uniquely identify the config class.
157+
- Implement `from_model_on_disk()`:
158+
- Validate the model is the correct format (file vs directory).
159+
- Inspect state dict keys, tensor shapes, or config files.
160+
- Raise `NotAMatchError` if the model doesn't match.
161+
- Extract any additional metadata needed (e.g., variant, prediction type).
162+
- Return an instance of the config class.
163+
- Register the config in `configs/factory.py`:
164+
- Add the config class to the `AnyModelConfig` union.
165+
- Add an `Annotated[YourConfig, YourConfig.get_tag()]` entry.
166+
167+
#### 3. Implement Loader Class
168+
169+
- Create a new loader file in `load/model_loaders/` (e.g., `load/model_loaders/new_model.py`).
170+
- Define a loader class inheriting from `ModelLoader`.
171+
- Decorate with `@ModelLoaderRegistry.register(base=..., type=..., format=...)`.
172+
- Implement `_load_model()`:
173+
- Load model weights from `config.path`.
174+
- Instantiate the model using the appropriate library (diffusers, transformers, or custom).
175+
- Handle `submodel_type` if the model has submodels (e.g., text encoders, VAE).
176+
- Return the in-memory model representation.
177+
178+
#### 4. Add Tests
179+
180+
Follow the instructions in `tests/model_identification/README.md`.
181+
182+
#### 5. Implement Inference and Nodes (if needed)
183+
184+
- If the model type requires new inference logic, implement it in an appropriate location.
185+
- Create nodes for the model if it doesn't fit into existing nodes. Search for subclasses of `BaseInvocation` for many examples.
186+
187+
### 6. Frontend Support
188+
189+
#### Workflows tab
190+
191+
Typically, you will not need to do anything for the model to work in the Workflow Editor. When you define the node's model field, you can provide constraints for what type of models are selectable. The UI will automatically filter the list of models based on the model taxonomy.
192+
193+
For example, this field definition in a node will allow users to select only "main" (pipeline) Stable Diffusion 1.x or 2.x models:
194+
195+
```py
196+
model: ModelIdentifierField = InputField(
197+
ui_model_base=[BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2],
198+
ui_model_type=ModelType.Main,
199+
)
200+
```
201+
202+
This same pattern works for any combination of `type`, `base`, `format`, and `variant`.
203+
204+
#### Canvas / Generate tabs
205+
206+
The Canvas and Generate tabs use graphs internally, but they don't expose the full graph editor UI. Instead, they provide a simplified interface for common tasks.
207+
208+
They use "graph builder" functions, which take the user's selected settings and build a graph behind the scenes. We have one graph builder for each model architecture.
209+
210+
Updating or adding a graph builder can be a bit complex, and you'd likely need to update other UI components and state management to support the new model type.
211+
212+
The SDXL graph builder is a good example: `invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts`

0 commit comments

Comments
 (0)