diff --git a/pmpt/DB.py b/pmpt/DB.py new file mode 100644 index 0000000..0dce35a --- /dev/null +++ b/pmpt/DB.py @@ -0,0 +1,114 @@ +from tinydb import TinyDB, Storage, Query + +# from . import util +import os +import warnings +from moyanlib import jsons +import io +from tqdm import tqdm +import zstandard +from typing import Dict, Any, Optional + + +class FastJSONStorage(Storage): + def __init__( + self, + path: str, + create_dirs=False, + encoding=None, + access_mode="rb+", + write_threshold=1, + **kwargs + ): + super().__init__() + self.cctx = zstandard.ZstdCompressor() + self.dctx = zstandard.ZstdDecompressor() + self._mode = access_mode + self.kwargs = kwargs + self.write_threshold = write_threshold # 指定写入阈值 + self.write_counter = 0 + if any( + [character in self._mode for character in ("+", "w", "a")] + ): # any of the writing modes + self.touch(path, create_dirs=create_dirs) + self._handle = open(path, mode=self._mode, encoding=encoding) + + def close(self) -> None: + self._handle.close() + + def touch(self, path: str, create_dirs: bool): + if create_dirs: + base_dir = os.path.dirname(path) + + if not os.path.exists(base_dir): + os.makedirs(base_dir) + + with open(path, "a"): + pass + + def read(self) -> Optional[Dict[str, Dict[str, Any]]]: + self._handle.seek(0, os.SEEK_END) + size = self._handle.tell() + + if not size: + return None + else: + self._handle.seek(0) + data = self.dctx.decompress(self._handle.read()) + return jsons.loads(data.decode()) + + def write(self, data: Dict[str, Dict[str, Any]]): + self.write_counter += 1 + if self.write_counter == self.write_threshold: + self._handle.seek(0) + + serialized = jsons.dumps(data, **self.kwargs) + serialized = self.cctx.compress(serialized.encode()) + + try: + self._handle.write(serialized) + except io.UnsupportedOperation: + raise IOError( + 'Cannot write to the database. Access mode is "{0}"'.format( + self._mode + ) + ) + + self._handle.flush() + os.fsync(self._handle.fileno()) + + self._handle.truncate() + self.write_counter = 0 + + +class Base: + def __init__(self, name): + self.rootdb = TinyDB( + os.path.join(".", "DB.pcdb"), storage=FastJSONStorage, write_threshold=2 + ) + self.db = self.rootdb.table(name) + + def insert(self, data): + self.db.insert(data) + + def remove(self, query): + self.db.remove(query) + + +class PackageData(Base): + def __init__(self): + super(PackageData, self).__init__("PackageData") + self.query = Query() + + def add(self, name, version, info=None): + self.db.insert({"Name": name, "Version": version, "info": info}) + + def get(self, name): + return self.db.search(self.query.Name == name)[0] + + +db = PackageData() +for i in tqdm(range(1145)): + db.add("json" + str(i), "v1" + str(i)) + +print(db.db.all()) diff --git a/pmpt/update.py b/pmpt/update.py index a06b892..d3c4936 100644 --- a/pmpt/update.py +++ b/pmpt/update.py @@ -5,6 +5,8 @@ import dill import os from .util import dirs, console +packageNum = 0 + def getSourceID(url): """ @@ -26,7 +28,12 @@ class Index: def getIndex(url): - req = requests.get(url["url"]) # 请求HTML + global packageNum + try: + req = requests.get(url["url"]) # 请求HTML + except Exception: + console.print("[red]Unable to connect to source[/red]") + return False HTMLIndex = req.text ClassIndex = Index(url) @@ -41,10 +48,12 @@ def getIndex(url): ClassIndex.addPackage(package_name) # 添加包 console.print("Total number of packages:", str(ClassIndex.number)) + packageNum += ClassIndex.number console.print('📚 Saving index..."') dill.dump( ClassIndex, open(f"{dirs.user_data_dir}/Index/{getSourceID(url)}.pidx", "wb") ) + return True def getAllIndex(): @@ -63,5 +72,9 @@ def getAllIndex(): for source in SourceList: # 遍历源列表 console.print("📚 Downloading index from", source["url"] + "...") - getIndex(source) - console.print("✅ [green]Index downloaded successfully![/green]") + sta = getIndex(source) + if sta: + console.print("✅ [green]Index downloaded successfully![/green]") + else: + console.print("❌ [red]Index download failed.[/red]") + # print(packageNum) diff --git a/pmpt/util.py b/pmpt/util.py index 6643d97..1ee1a48 100644 --- a/pmpt/util.py +++ b/pmpt/util.py @@ -24,21 +24,6 @@ def getVer(baseVar): return baseVar -def GlobalDecorator(frame, event, arg): - if event == "call": - try: - func_name = frame.f_code.co_name - module_name = frame.f_globals["__name__"] - package_name = module_name.split(".")[0] # 假设包名为模块名的第一部分 - if package_name == "pmpt": - logger.trace(f"调用函数 {module_name}.{func_name}") - except: - pass - return GlobalDecorator - - -sys.settrace(GlobalDecorator) - logger.remove() logger.add( os.path.join(dirs.user_data_dir, "log.log"),