优化性能及添加数据库ORM

This commit is contained in:
root 2024-04-13 23:43:49 +08:00
parent 83f1d53106
commit a099a5339b
3 changed files with 130 additions and 18 deletions

114
pmpt/DB.py Normal file
View File

@ -0,0 +1,114 @@
from tinydb import TinyDB, Storage, Query
# from . import util
import os
import warnings
from moyanlib import jsons
import io
from tqdm import tqdm
import zstandard
from typing import Dict, Any, Optional
class FastJSONStorage(Storage):
def __init__(
self,
path: str,
create_dirs=False,
encoding=None,
access_mode="rb+",
write_threshold=1,
**kwargs
):
super().__init__()
self.cctx = zstandard.ZstdCompressor()
self.dctx = zstandard.ZstdDecompressor()
self._mode = access_mode
self.kwargs = kwargs
self.write_threshold = write_threshold # 指定写入阈值
self.write_counter = 0
if any(
[character in self._mode for character in ("+", "w", "a")]
): # any of the writing modes
self.touch(path, create_dirs=create_dirs)
self._handle = open(path, mode=self._mode, encoding=encoding)
def close(self) -> None:
self._handle.close()
def touch(self, path: str, create_dirs: bool):
if create_dirs:
base_dir = os.path.dirname(path)
if not os.path.exists(base_dir):
os.makedirs(base_dir)
with open(path, "a"):
pass
def read(self) -> Optional[Dict[str, Dict[str, Any]]]:
self._handle.seek(0, os.SEEK_END)
size = self._handle.tell()
if not size:
return None
else:
self._handle.seek(0)
data = self.dctx.decompress(self._handle.read())
return jsons.loads(data.decode())
def write(self, data: Dict[str, Dict[str, Any]]):
self.write_counter += 1
if self.write_counter == self.write_threshold:
self._handle.seek(0)
serialized = jsons.dumps(data, **self.kwargs)
serialized = self.cctx.compress(serialized.encode())
try:
self._handle.write(serialized)
except io.UnsupportedOperation:
raise IOError(
'Cannot write to the database. Access mode is "{0}"'.format(
self._mode
)
)
self._handle.flush()
os.fsync(self._handle.fileno())
self._handle.truncate()
self.write_counter = 0
class Base:
def __init__(self, name):
self.rootdb = TinyDB(
os.path.join(".", "DB.pcdb"), storage=FastJSONStorage, write_threshold=2
)
self.db = self.rootdb.table(name)
def insert(self, data):
self.db.insert(data)
def remove(self, query):
self.db.remove(query)
class PackageData(Base):
def __init__(self):
super(PackageData, self).__init__("PackageData")
self.query = Query()
def add(self, name, version, info=None):
self.db.insert({"Name": name, "Version": version, "info": info})
def get(self, name):
return self.db.search(self.query.Name == name)[0]
db = PackageData()
for i in tqdm(range(1145)):
db.add("json" + str(i), "v1" + str(i))
print(db.db.all())

View File

@ -5,6 +5,8 @@ import dill
import os import os
from .util import dirs, console from .util import dirs, console
packageNum = 0
def getSourceID(url): def getSourceID(url):
""" """
@ -26,7 +28,12 @@ class Index:
def getIndex(url): def getIndex(url):
global packageNum
try:
req = requests.get(url["url"]) # 请求HTML req = requests.get(url["url"]) # 请求HTML
except Exception:
console.print("[red]Unable to connect to source[/red]")
return False
HTMLIndex = req.text HTMLIndex = req.text
ClassIndex = Index(url) ClassIndex = Index(url)
@ -41,10 +48,12 @@ def getIndex(url):
ClassIndex.addPackage(package_name) # 添加包 ClassIndex.addPackage(package_name) # 添加包
console.print("Total number of packages:", str(ClassIndex.number)) console.print("Total number of packages:", str(ClassIndex.number))
packageNum += ClassIndex.number
console.print('📚 Saving index..."') console.print('📚 Saving index..."')
dill.dump( dill.dump(
ClassIndex, open(f"{dirs.user_data_dir}/Index/{getSourceID(url)}.pidx", "wb") ClassIndex, open(f"{dirs.user_data_dir}/Index/{getSourceID(url)}.pidx", "wb")
) )
return True
def getAllIndex(): def getAllIndex():
@ -63,5 +72,9 @@ def getAllIndex():
for source in SourceList: # 遍历源列表 for source in SourceList: # 遍历源列表
console.print("📚 Downloading index from", source["url"] + "...") console.print("📚 Downloading index from", source["url"] + "...")
getIndex(source) sta = getIndex(source)
if sta:
console.print("✅ [green]Index downloaded successfully![/green]") console.print("✅ [green]Index downloaded successfully![/green]")
else:
console.print("❌ [red]Index download failed.[/red]")
# print(packageNum)

View File

@ -24,21 +24,6 @@ def getVer(baseVar):
return baseVar return baseVar
def GlobalDecorator(frame, event, arg):
if event == "call":
try:
func_name = frame.f_code.co_name
module_name = frame.f_globals["__name__"]
package_name = module_name.split(".")[0] # 假设包名为模块名的第一部分
if package_name == "pmpt":
logger.trace(f"调用函数 {module_name}.{func_name}")
except:
pass
return GlobalDecorator
sys.settrace(GlobalDecorator)
logger.remove() logger.remove()
logger.add( logger.add(
os.path.join(dirs.user_data_dir, "log.log"), os.path.join(dirs.user_data_dir, "log.log"),