优化性能及添加数据库ORM
This commit is contained in:
parent
83f1d53106
commit
a099a5339b
114
pmpt/DB.py
Normal file
114
pmpt/DB.py
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
from tinydb import TinyDB, Storage, Query
|
||||||
|
|
||||||
|
# from . import util
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from moyanlib import jsons
|
||||||
|
import io
|
||||||
|
from tqdm import tqdm
|
||||||
|
import zstandard
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class FastJSONStorage(Storage):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
path: str,
|
||||||
|
create_dirs=False,
|
||||||
|
encoding=None,
|
||||||
|
access_mode="rb+",
|
||||||
|
write_threshold=1,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.cctx = zstandard.ZstdCompressor()
|
||||||
|
self.dctx = zstandard.ZstdDecompressor()
|
||||||
|
self._mode = access_mode
|
||||||
|
self.kwargs = kwargs
|
||||||
|
self.write_threshold = write_threshold # 指定写入阈值
|
||||||
|
self.write_counter = 0
|
||||||
|
if any(
|
||||||
|
[character in self._mode for character in ("+", "w", "a")]
|
||||||
|
): # any of the writing modes
|
||||||
|
self.touch(path, create_dirs=create_dirs)
|
||||||
|
self._handle = open(path, mode=self._mode, encoding=encoding)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
self._handle.close()
|
||||||
|
|
||||||
|
def touch(self, path: str, create_dirs: bool):
|
||||||
|
if create_dirs:
|
||||||
|
base_dir = os.path.dirname(path)
|
||||||
|
|
||||||
|
if not os.path.exists(base_dir):
|
||||||
|
os.makedirs(base_dir)
|
||||||
|
|
||||||
|
with open(path, "a"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def read(self) -> Optional[Dict[str, Dict[str, Any]]]:
|
||||||
|
self._handle.seek(0, os.SEEK_END)
|
||||||
|
size = self._handle.tell()
|
||||||
|
|
||||||
|
if not size:
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
self._handle.seek(0)
|
||||||
|
data = self.dctx.decompress(self._handle.read())
|
||||||
|
return jsons.loads(data.decode())
|
||||||
|
|
||||||
|
def write(self, data: Dict[str, Dict[str, Any]]):
|
||||||
|
self.write_counter += 1
|
||||||
|
if self.write_counter == self.write_threshold:
|
||||||
|
self._handle.seek(0)
|
||||||
|
|
||||||
|
serialized = jsons.dumps(data, **self.kwargs)
|
||||||
|
serialized = self.cctx.compress(serialized.encode())
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._handle.write(serialized)
|
||||||
|
except io.UnsupportedOperation:
|
||||||
|
raise IOError(
|
||||||
|
'Cannot write to the database. Access mode is "{0}"'.format(
|
||||||
|
self._mode
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._handle.flush()
|
||||||
|
os.fsync(self._handle.fileno())
|
||||||
|
|
||||||
|
self._handle.truncate()
|
||||||
|
self.write_counter = 0
|
||||||
|
|
||||||
|
|
||||||
|
class Base:
|
||||||
|
def __init__(self, name):
|
||||||
|
self.rootdb = TinyDB(
|
||||||
|
os.path.join(".", "DB.pcdb"), storage=FastJSONStorage, write_threshold=2
|
||||||
|
)
|
||||||
|
self.db = self.rootdb.table(name)
|
||||||
|
|
||||||
|
def insert(self, data):
|
||||||
|
self.db.insert(data)
|
||||||
|
|
||||||
|
def remove(self, query):
|
||||||
|
self.db.remove(query)
|
||||||
|
|
||||||
|
|
||||||
|
class PackageData(Base):
|
||||||
|
def __init__(self):
|
||||||
|
super(PackageData, self).__init__("PackageData")
|
||||||
|
self.query = Query()
|
||||||
|
|
||||||
|
def add(self, name, version, info=None):
|
||||||
|
self.db.insert({"Name": name, "Version": version, "info": info})
|
||||||
|
|
||||||
|
def get(self, name):
|
||||||
|
return self.db.search(self.query.Name == name)[0]
|
||||||
|
|
||||||
|
|
||||||
|
db = PackageData()
|
||||||
|
for i in tqdm(range(1145)):
|
||||||
|
db.add("json" + str(i), "v1" + str(i))
|
||||||
|
|
||||||
|
print(db.db.all())
|
|
@ -5,6 +5,8 @@ import dill
|
||||||
import os
|
import os
|
||||||
from .util import dirs, console
|
from .util import dirs, console
|
||||||
|
|
||||||
|
packageNum = 0
|
||||||
|
|
||||||
|
|
||||||
def getSourceID(url):
|
def getSourceID(url):
|
||||||
"""
|
"""
|
||||||
|
@ -26,7 +28,12 @@ class Index:
|
||||||
|
|
||||||
|
|
||||||
def getIndex(url):
|
def getIndex(url):
|
||||||
req = requests.get(url["url"]) # 请求HTML
|
global packageNum
|
||||||
|
try:
|
||||||
|
req = requests.get(url["url"]) # 请求HTML
|
||||||
|
except Exception:
|
||||||
|
console.print("[red]Unable to connect to source[/red]")
|
||||||
|
return False
|
||||||
HTMLIndex = req.text
|
HTMLIndex = req.text
|
||||||
|
|
||||||
ClassIndex = Index(url)
|
ClassIndex = Index(url)
|
||||||
|
@ -41,10 +48,12 @@ def getIndex(url):
|
||||||
ClassIndex.addPackage(package_name) # 添加包
|
ClassIndex.addPackage(package_name) # 添加包
|
||||||
|
|
||||||
console.print("Total number of packages:", str(ClassIndex.number))
|
console.print("Total number of packages:", str(ClassIndex.number))
|
||||||
|
packageNum += ClassIndex.number
|
||||||
console.print('📚 Saving index..."')
|
console.print('📚 Saving index..."')
|
||||||
dill.dump(
|
dill.dump(
|
||||||
ClassIndex, open(f"{dirs.user_data_dir}/Index/{getSourceID(url)}.pidx", "wb")
|
ClassIndex, open(f"{dirs.user_data_dir}/Index/{getSourceID(url)}.pidx", "wb")
|
||||||
)
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def getAllIndex():
|
def getAllIndex():
|
||||||
|
@ -63,5 +72,9 @@ def getAllIndex():
|
||||||
|
|
||||||
for source in SourceList: # 遍历源列表
|
for source in SourceList: # 遍历源列表
|
||||||
console.print("📚 Downloading index from", source["url"] + "...")
|
console.print("📚 Downloading index from", source["url"] + "...")
|
||||||
getIndex(source)
|
sta = getIndex(source)
|
||||||
console.print("✅ [green]Index downloaded successfully![/green]")
|
if sta:
|
||||||
|
console.print("✅ [green]Index downloaded successfully![/green]")
|
||||||
|
else:
|
||||||
|
console.print("❌ [red]Index download failed.[/red]")
|
||||||
|
# print(packageNum)
|
||||||
|
|
15
pmpt/util.py
15
pmpt/util.py
|
@ -24,21 +24,6 @@ def getVer(baseVar):
|
||||||
return baseVar
|
return baseVar
|
||||||
|
|
||||||
|
|
||||||
def GlobalDecorator(frame, event, arg):
|
|
||||||
if event == "call":
|
|
||||||
try:
|
|
||||||
func_name = frame.f_code.co_name
|
|
||||||
module_name = frame.f_globals["__name__"]
|
|
||||||
package_name = module_name.split(".")[0] # 假设包名为模块名的第一部分
|
|
||||||
if package_name == "pmpt":
|
|
||||||
logger.trace(f"调用函数 {module_name}.{func_name}")
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
return GlobalDecorator
|
|
||||||
|
|
||||||
|
|
||||||
sys.settrace(GlobalDecorator)
|
|
||||||
|
|
||||||
logger.remove()
|
logger.remove()
|
||||||
logger.add(
|
logger.add(
|
||||||
os.path.join(dirs.user_data_dir, "log.log"),
|
os.path.join(dirs.user_data_dir, "log.log"),
|
||||||
|
|
Loading…
Reference in New Issue
Block a user