Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 57 additions & 37 deletions comfyui_to_python.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from nodes import NODE_CLASS_MAPPINGS, init_builtin_extra_nodes, init_external_custom_nodes
import copy
import glob
import inspect
Expand All @@ -11,10 +12,9 @@
import black


from utils import import_custom_nodes, find_path, add_comfyui_directory_to_sys_path, add_extra_model_paths, get_value_at_index
from utils import find_path, add_comfyui_directory_to_sys_path, add_extra_model_paths, get_value_at_index

sys.path.append('../')
from nodes import NODE_CLASS_MAPPINGS


class FileHandler:
Expand Down Expand Up @@ -59,7 +59,8 @@ def read_json_file(file_path: str) -> dict:
# Format the list of JSON files as a string
json_files_str = "\n".join(json_files)

raise FileNotFoundError(f"\n\nFile not found: {file_path}. JSON files in the directory:\n{json_files_str}")
raise FileNotFoundError(
f"\n\nFile not found: {file_path}. JSON files in the directory:\n{json_files_str}")

except json.JSONDecodeError:
raise ValueError(f"Invalid JSON format in file: {file_path}")
Expand Down Expand Up @@ -153,11 +154,12 @@ def _load_special_functions_first(self) -> None:
"""
# Iterate over each key in the data to check for loader keys.
for key in self.data:
class_def = self.node_class_mappings[self.data[key]['class_type']]()
class_def = self.node_class_mappings[self.data[key]['class_type']](
)
# Check if the class is a loader class or meets specific conditions.
if (class_def.CATEGORY == 'loaders' or
class_def.FUNCTION in ['encode'] or
not any(isinstance(val, list) for val in self.data[key]['inputs'].values())):
if (class_def.CATEGORY == 'loaders' or
class_def.FUNCTION in ['encode'] or
not any(isinstance(val, list) for val in self.data[key]['inputs'].values())):
self.is_special_function = True
# If the key has not been visited, perform a DFS from that key.
if key not in self.visited:
Expand Down Expand Up @@ -195,7 +197,8 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
str: Generated execution code as a string.
"""
# Create the necessary data structures to hold imports and generated code
import_statements, executed_variables, special_functions_code, code = set(['NODE_CLASS_MAPPINGS']), {}, [], []
import_statements, executed_variables, special_functions_code, code = set(
['NODE_CLASS_MAPPINGS']), {}, [], []
# This dictionary will store the names of the objects that we have already initialized
initialized_objects = {}

Expand All @@ -213,34 +216,41 @@ def generate_workflow(self, load_order: List, filename: str = 'generated_code_wo
if class_type == 'PreviewImage':
continue

class_type, import_statement, class_code = self.get_class_info(class_type)
initialized_objects[class_type] = self.clean_variable_name(class_type)
class_type, import_statement, class_code = self.get_class_info(
class_type)
initialized_objects[class_type] = self.clean_variable_name(
class_type)
if class_type in self.base_node_class_mappings.keys():
import_statements.add(import_statement)
if class_type not in self.base_node_class_mappings.keys():
custom_nodes = True
special_functions_code.append(class_code)

# Get all possible parameters for class_def
class_def_params = self.get_function_parameters(getattr(class_def, class_def.FUNCTION))
class_def_params = self.get_function_parameters(
getattr(class_def, class_def.FUNCTION))

# Remove any keyword arguments from **inputs if they are not in class_def_params
inputs = {key: value for key, value in inputs.items() if key in class_def_params}
inputs = {key: value for key,
value in inputs.items() if key in class_def_params}
# Deal with hidden variables
if 'unique_id' in class_def_params:
inputs['unique_id'] = random.randint(1, 2**64)

# Create executed variable and generate code
executed_variables[idx] = f'{self.clean_variable_name(class_type)}_{idx}'
inputs = self.update_inputs(inputs, executed_variables)

if is_special_function:
special_functions_code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, **inputs))
special_functions_code.append(self.create_function_call_code(
initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, **inputs))
else:
code.append(self.create_function_call_code(initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, **inputs))
code.append(self.create_function_call_code(
initialized_objects[class_type], class_def.FUNCTION, executed_variables[idx], is_special_function, **inputs))

# Generate final code by combining imports and code, and wrap them in a main function
final_code = self.assemble_python_code(import_statements, special_functions_code, code, queue_size, custom_nodes)
final_code = self.assemble_python_code(
import_statements, special_functions_code, code, queue_size, custom_nodes)

return final_code

Expand All @@ -257,7 +267,8 @@ def create_function_call_code(self, obj_name: str, func: str, variable_name: str
Returns:
str: The generated Python code.
"""
args = ', '.join(self.format_arg(key, value) for key, value in kwargs.items())
args = ', '.join(self.format_arg(key, value)
for key, value in kwargs.items())

# Generate the Python code
code = f'{variable_name} = {obj_name}.{func}({args})\n'
Expand Down Expand Up @@ -306,26 +317,30 @@ def assemble_python_code(self, import_statements: set, speical_functions_code: L
for func in [get_value_at_index, find_path, add_comfyui_directory_to_sys_path, add_extra_model_paths]:
func_strings.append(f'\n{inspect.getsource(func)}')
# Define static import statements required for the script
static_imports = ['import os', 'import random', 'import sys', 'from typing import Sequence, Mapping, Any, Union',
static_imports = ['import os', 'import random', 'import sys', 'from typing import Sequence, Mapping, Any, Union',
'import torch'] + func_strings + ['\n\nadd_comfyui_directory_to_sys_path()\nadd_extra_model_paths()\n']
# Check if custom nodes should be included
if custom_nodes:
static_imports.append(f'\n{inspect.getsource(import_custom_nodes)}\n')
custom_nodes = 'import_custom_nodes()\n\t'
static_imports.append(
f'\n{inspect.getsource(init_builtin_extra_nodes)}\n')
custom_nodes = 'init_builtin_extra_nodes()\n\t'
else:
custom_nodes = ''
# Create import statements for node classes
imports_code = [f"from nodes import {', '.join([class_name for class_name in import_statements])}" ]
imports_code = [
f"from nodes import {', '.join([class_name for class_name in import_statements])}"]
# Assemble the main function code, including custom nodes if applicable
main_function_code = "def main():\n\t" + f'{custom_nodes}with torch.inference_mode():\n\t\t' + '\n\t\t'.join(speical_functions_code) \
+ f'\n\n\t\tfor q in range({queue_size}):\n\t\t' + '\n\t\t'.join(code)
+ f'\n\n\t\tfor q in range({queue_size}):\n\t\t' + \
'\n\t\t'.join(code)
# Concatenate all parts to form the final code
final_code = '\n'.join(static_imports + imports_code + ['', main_function_code, '', 'if __name__ == "__main__":', '\tmain()'])
final_code = '\n'.join(static_imports + imports_code +
['', main_function_code, '', 'if __name__ == "__main__":', '\tmain()'])
# Format the final code according to PEP 8 using the Black library
final_code = black.format_str(final_code, mode=black.Mode())

return final_code

def get_class_info(self, class_type: str) -> Tuple[str, str, str]:
"""Generates and returns necessary information about class type.

Expand All @@ -343,7 +358,7 @@ def get_class_info(self, class_type: str) -> Tuple[str, str, str]:
class_code = f'{variable_name} = NODE_CLASS_MAPPINGS["{class_type}"]()'

return class_type, import_statement, class_code

@staticmethod
def clean_variable_name(class_type: str) -> str:
"""
Expand All @@ -357,14 +372,14 @@ def clean_variable_name(class_type: str) -> str:
"""
# Convert to lowercase and replace spaces with underscores
clean_name = class_type.lower().strip().replace("-", "_").replace(" ", "_")

# Remove characters that are not letters, numbers, or underscores
clean_name = re.sub(r'[^a-z0-9_]', '', clean_name)

# Ensure that it doesn't start with a number
if clean_name[0].isdigit():
clean_name = "_" + clean_name

return clean_name

def get_function_parameters(self, func: Callable) -> List:
Expand All @@ -378,8 +393,8 @@ def get_function_parameters(self, func: Callable) -> List:
"""
signature = inspect.signature(func)
parameters = {name: param.default if param.default != param.empty else None
for name, param in signature.parameters.items()}
return list(parameters.keys())
for name, param in signature.parameters.items()}
return list(parameters.keys())

def update_inputs(self, inputs: Dict, executed_variables: Dict) -> Dict:
"""Update inputs based on the executed variables.
Expand All @@ -393,9 +408,10 @@ def update_inputs(self, inputs: Dict, executed_variables: Dict) -> Dict:
"""
for key in inputs.keys():
if isinstance(inputs[key], list) and inputs[key][0] in executed_variables.keys():
inputs[key] = {'variable_name': f"get_value_at_index({executed_variables[inputs[key][0]]}, {inputs[key][1]})"}
inputs[key] = {
'variable_name': f"get_value_at_index({executed_variables[inputs[key][0]]}, {inputs[key][1]})"}
return inputs


class ComfyUItoPython:
"""Main workflow to generate Python code from a workflow_api.json file.
Expand Down Expand Up @@ -431,18 +447,21 @@ def execute(self):
None
"""
# Step 1: Import all custom nodes
import_custom_nodes()
init_builtin_extra_nodes()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried using init_extra_nodes instead of init_builtin_extra_nodes, do you know what the difference is?


# Step 2: Read JSON data from the input file
data = FileHandler.read_json_file(self.input_file)

# Step 3: Determine the load order
load_order_determiner = LoadOrderDeterminer(data, self.node_class_mappings)
load_order_determiner = LoadOrderDeterminer(
data, self.node_class_mappings)
load_order = load_order_determiner.determine_load_order()

# Step 4: Generate the workflow code
code_generator = CodeGenerator(self.node_class_mappings, self.base_node_class_mappings)
generated_code = code_generator.generate_workflow(load_order, filename=self.output_file, queue_size=self.queue_size)
code_generator = CodeGenerator(
self.node_class_mappings, self.base_node_class_mappings)
generated_code = code_generator.generate_workflow(
load_order, filename=self.output_file, queue_size=self.queue_size)

# Step 5: Write the generated code to a file
FileHandler.write_code_to_file(self.output_file, generated_code)
Expand All @@ -457,4 +476,5 @@ def execute(self):
queue_size = 10

# Convert ComfyUI workflow to Python
ComfyUItoPython(input_file=input_file, output_file=output_file, queue_size=queue_size)
ComfyUItoPython(input_file=input_file,
output_file=output_file, queue_size=queue_size)