110 lines
3.1 KiB
Python
110 lines
3.1 KiB
Python
import base64
|
|
import glob
|
|
import hashlib
|
|
import json
|
|
import math
|
|
import os
|
|
import pickle
|
|
import time
|
|
from typing import List
|
|
|
|
from IPython.core.magic import (register_cell_magic, needs_local_scope)
|
|
from IPython import get_ipython
|
|
|
|
|
|
@register_cell_magic
|
|
@needs_local_scope
|
|
def skip_if_false(line, cell, local_ns=None):
|
|
condition_var = eval(line, None, local_ns)
|
|
if condition_var:
|
|
get_ipython().run_cell(cell)
|
|
return None
|
|
return f"Skipped (evaluated {line} to False)"
|
|
|
|
|
|
def left_abbreviate(string, limit=120):
|
|
return string if len(string) <= limit else f"...{string[:limit - 3]}"
|
|
|
|
|
|
class ProgressPrinter:
|
|
def __init__(self, verb, n) -> None:
|
|
self.verb = verb
|
|
self.n = n
|
|
self.i = 0
|
|
self.fmt_len = math.ceil(math.log10(n if n > 0 else 1))
|
|
|
|
def step(self, msg):
|
|
self.i += 1
|
|
print(f"({self.i:>{self.fmt_len}d}/{self.n}) {self.verb} {left_abbreviate(msg):<120}", end="\r")
|
|
|
|
def __enter__(self):
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_value, exc_traceback):
|
|
self.i -= 1
|
|
|
|
if exc_value:
|
|
self.step("error.")
|
|
print()
|
|
print(exc_value)
|
|
return
|
|
|
|
self.step("done.")
|
|
print()
|
|
|
|
|
|
def stable_hash(obj):
|
|
return hashlib.md5(json.dumps(obj).encode("utf-8")).hexdigest()[:10]
|
|
|
|
|
|
def parse_as(type, string):
|
|
if any(issubclass(type, type2) for type2 in (str, bool, float, int)):
|
|
return type(string)
|
|
if issubclass(type, list) or issubclass(type, dict) or issubclass(type, set):
|
|
val = json.loads(string)
|
|
return type(val)
|
|
raise ValueError(f"Unknown type {type.__name__}")
|
|
|
|
def cached(name, function, file_deps: List[str]):
|
|
if not os.path.isdir("cache"):
|
|
os.makedirs("cache", exist_ok=True)
|
|
|
|
dep_time = 0.0
|
|
for file in file_deps:
|
|
# Get modified time of the current dependency
|
|
m_time = os.path.getmtime(file) if os.path.exists(file) else 0.
|
|
|
|
# Update dependency time to be the newest modified time of any dependency
|
|
if m_time > dep_time:
|
|
dep_time = m_time
|
|
|
|
# Check directories recursively to get the newest modified time
|
|
for root, dirs, files in os.walk(file):
|
|
for f in files + dirs:
|
|
filename = os.path.join(root, f)
|
|
m_time = os.path.getmtime(filename)
|
|
|
|
if m_time > dep_time:
|
|
dep_time = m_time
|
|
|
|
deps_hash = stable_hash(sorted(file_deps))
|
|
pkl_filename = f"cache/{name}_{deps_hash}.pkl"
|
|
|
|
pkl_time = os.path.getmtime(pkl_filename) if os.path.exists(pkl_filename) else 0.
|
|
|
|
if pkl_time > dep_time:
|
|
with open(pkl_filename, "rb") as f:
|
|
print(f"[CACHE] Found up-to-date cache entry ({pkl_filename}) for {name}, loading.")
|
|
return pickle.load(f)
|
|
|
|
if os.path.exists(pkl_filename):
|
|
print(f"[CACHE] Data dependencies for {name} changed, deleting cached version.")
|
|
os.remove(pkl_filename)
|
|
|
|
print(f"[CACHE] Creating cache entry for {name} (in {pkl_filename}).")
|
|
obj = function()
|
|
|
|
with open(pkl_filename, "wb") as f:
|
|
pickle.dump(obj, f)
|
|
|
|
return obj
|