You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

148 lines
5.4 KiB

"""Miscellaneous training utilities: W&B helpers, dynamic imports, OmegaConf tools, and timers."""
import wandb
import importlib
import os
import time
from omegaconf import OmegaConf, DictConfig, ListConfig
def wandb_run_exists():
    """Report whether ``wandb.run`` currently holds a ``wandb.sdk.wandb_run.Run`` instance."""
    current_run = wandb.run
    return isinstance(current_run, wandb.sdk.wandb_run.Run)
def import_type_from_str(s):
    """Import and return the object named by a dotted path.

    Everything before the last dot is imported as a module; the text after
    the last dot is then looked up as an attribute of that module, e.g.
    ``"collections.OrderedDict"``.

    Args:
        s: Dotted path of the form ``"package.module.name"``.

    Returns:
        The attribute ``name`` of the imported module.

    Raises:
        ValueError: If ``s`` contains no dot or has nothing before the last dot.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    module_name, separator, type_name = s.rpartition(".")
    if not separator or not module_name:
        # Fail with the offending path in the message instead of a bare
        # tuple-unpacking / empty-module-name error.
        raise ValueError(
            f"Expected a dotted path like 'package.module.Name', got {s!r}"
        )
    module = importlib.import_module(module_name)
    return getattr(module, type_name)
def recursive_set_struct(cfg, struct_value: bool):
    """Set the OmegaConf struct flag on ``cfg`` and on every nested DictConfig/ListConfig.

    Args:
        cfg: Config node passed to ``OmegaConf.set_struct``.
        struct_value: Flag value to apply at every level.
    """
    OmegaConf.set_struct(cfg, struct_value)
    container_types = (DictConfig, ListConfig)

    if isinstance(cfg, ListConfig):
        for element in cfg:
            if isinstance(element, container_types):
                recursive_set_struct(element, struct_value)
        return

    if not isinstance(cfg, DictConfig):
        return

    for name in cfg.keys():
        # Best effort: any error raised while reading an entry or descending
        # into it is ignored and the walk continues with the next key.
        try:
            child = cfg[name]
            if isinstance(child, container_types):
                recursive_set_struct(child, struct_value)
        except Exception:
            pass
def get_filtered_state_dict(state_dict, state_dict_key):
    """
    Select the entries whose key begins with a prefix and re-key them without it.

    The prefix is matched as a plain string (no dot boundary is required), and
    any dots directly following the prefix are stripped from the new key.

    Args:
        state_dict: Dictionary of state dict keys and values
        state_dict_key: Prefix string a key must start with to be kept
    Returns:
        New dictionary with the kept entries, prefix removed from keys
    """
    prefix_length = len(state_dict_key)
    return {
        name[prefix_length:].lstrip("."): entry
        for name, entry in state_dict.items()
        if name.startswith(state_dict_key)
    }
def custom_instantiate(d, _resolve=True, _recursive=False, **add_kwargs):
    """
    Build the object described by a config that carries a ``_target_`` import path.

    The ``_target_`` entry is imported with ``import_type_from_str`` and called
    with the remaining entries as keyword arguments. The bookkeeping keys
    ``_recursive_``, ``_convert_`` and ``_partial_`` are dropped before the call.

    Args:
        d: Plain dict or DictConfig containing ``_target_``. A copy is taken
            before any key is removed.
        _resolve: Only used when ``d`` is a DictConfig. When true the config is
            converted with ``OmegaConf.to_container``; when false the DictConfig
            copy is kept and its struct flag is cleared so keys can be popped.
        _recursive: When true, dict and list values are searched for nested
            ``_target_`` configs, which are instantiated as well.
        **add_kwargs: Extra keyword arguments for the target. On a name clash
            they take precedence over the entries of ``d``.
    Returns:
        Whatever calling the imported target returns.
    Raises:
        AssertionError: If a config sets ``_recursive_`` to a value equal to True.
        KeyError: If a plain-dict ``d`` has no ``_target_`` entry.
    """
    meta_keys = ("_recursive_", "_convert_", "_partial_")

    def _pop_target(cfg):
        # Validate cfg, strip the bookkeeping keys in place and return the imported target.
        # Equality (not identity) is kept on purpose so the check accepts the same
        # values as before. Raised explicitly because `assert` is removed under -O.
        if cfg.get("_recursive_", None) == True:  # noqa: E712
            raise AssertionError("recursive is not supported")
        for meta_key in meta_keys:
            cfg.pop(meta_key, None)
        return import_type_from_str(cfg.pop("_target_"))

    def _instantiate_values(cfg):
        # Replace every dict/list value of cfg in place with its instantiated form.
        for key, value in list(cfg.items()):
            if isinstance(value, (dict, DictConfig)):
                cfg[key] = _recursive_instantiate(value)
            elif isinstance(value, (list, ListConfig)):
                cfg[key] = [_recursive_instantiate(item) for item in value]

    def _recursive_instantiate(obj):
        # Instantiate obj if it is a dict with a _target_, otherwise walk into it.
        # NOTE(review): only plain dict/list are matched here, so with
        # _resolve=False nested DictConfig values appear to be returned
        # unchanged rather than instantiated - confirm this is intended.
        if isinstance(obj, dict):
            if "_target_" not in obj:
                return {k: _recursive_instantiate(v) for k, v in obj.items()}
            obj = obj.copy()
            target = _pop_target(obj)
            _instantiate_values(obj)
            return target(**obj)
        if isinstance(obj, list):
            return [_recursive_instantiate(item) for item in obj]
        return obj

    # Work on a copy so the caller's config keeps its _target_ and bookkeeping keys.
    d = d.copy()
    if isinstance(d, DictConfig):
        if _resolve:
            d = OmegaConf.to_container(d, resolve=_resolve)
        else:
            recursive_set_struct(d, False)

    target = _pop_target(d)
    if _recursive:
        _instantiate_values(d)

    # Merge before the call so add_kwargs overrides config entries; passing both
    # mappings with ** directly raises TypeError on a duplicate key.
    return target(**{**d, **add_kwargs})
# Nesting depth of the timers currently running; each Timer captures it on
# entry to indent its report line and restores it on exit.
timer_indent_level = 0


class Timer:
    """Context manager that prints how long its ``with`` body took.

    Timing is active only when ``instance_enabled`` is true and the
    environment variable ``TIMER_ENABLED`` equals ``"1"`` at construction
    time. Further environment variables read at construction:

    * ``LOCAL_RANK`` - integer stored in ``self.rank`` (0 when unset).
    * ``TIMER_SHOW_RANK`` - ``"1"`` prefixes the report with ``[rank<N>]``.
    * ``TIMER_RANK_ZERO_ONLY`` - ``"1"`` makes ranks other than 0 do nothing.

    Nested timers are indented by four spaces per nesting level. If the body
    raises, nothing is printed and the exception propagates.
    """

    def __init__(self, name="", instance_enabled=True):
        self.name = name
        self.start_time = None
        self.current_indent = 0
        self.enabled = instance_enabled and os.environ.get("TIMER_ENABLED", "0") == "1"
        self.rank = int(os.environ.get("LOCAL_RANK", "0"))
        self.show_rank = os.environ.get("TIMER_SHOW_RANK", "0") == "1"
        self.rank_zero_only = os.environ.get("TIMER_RANK_ZERO_ONLY", "0") == "1"

    def _is_active(self):
        """Return True when this instance should measure and report."""
        if not self.enabled:
            return False
        return not (self.rank_zero_only and self.rank != 0)

    def __enter__(self):
        global timer_indent_level
        if self._is_active():
            self.current_indent = timer_indent_level  # Depth of this timer
            timer_indent_level += 1  # Timers started inside the body nest one level deeper
            self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        global timer_indent_level
        if not self._is_active():
            return False
        # Undo the increment from __enter__ even when the body raised;
        # otherwise every later timer would stay indented one level too deep.
        timer_indent_level -= 1
        if exc_type is None:
            elapsed_time = time.perf_counter() - self.start_time
            indent = "    " * self.current_indent  # 4 spaces per indent level
            rank_str = f"[rank{self.rank}] " if self.show_rank else ""
            print(f"{indent}{rank_str}[{self.name}] time: {elapsed_time:.4f} seconds")
        return False  # Never suppress an exception from the body