diff --git a/pr_review/src/config/init_config.py b/pr_review/src/config/init_config.py index 0fd1f0f770cbd18a0a8ed8c54764950f7625fd1f..f738e13c87e31bf9f24b688a79cf0a649ac66aee 100644 --- a/pr_review/src/config/init_config.py +++ b/pr_review/src/config/init_config.py @@ -1,24 +1,17 @@ +import os import yaml from gitee.gitee_api import GiteeCaller from gpt.bot import Gpt +config = {} -def init_config(path): - with open(path, "r", encoding="utf-8") as f: - config = yaml.safe_load(f) +path = os.getenv("APPLICATION_PATH", "config.yaml") +with open(path, "r", encoding="utf-8") as f: + config = yaml.safe_load(f) +def init_wokers(): GiteeCaller.init_config_attr(config["access_token"], config["gitee_host"]) - - Gpt.init_config_attr( - config["gpt"]["max_token_length"], - config["gpt"]["encoding_name"], - config["gpt"]["url"], - config["gpt"]["limit"], - config["gpt"]["prompt"], - ) - - - + Gpt.init_config_attr(config) diff --git a/pr_review/src/gpt/bot.py b/pr_review/src/gpt/bot.py index d520d2a12195e52ace4a75f6cddb914445a17af1..d5cd1f6ca9bb994b87bbcb8a059a189f8d416bb0 100644 --- a/pr_review/src/gpt/bot.py +++ b/pr_review/src/gpt/bot.py @@ -10,13 +10,21 @@ class Gpt: url = '' limit = 5 prompt = '' + auth_url = '' + app_id = '' + app_secret = '' - def init_config_attr(max_token_length, encoding_name, url, limit, prompt): - Gpt.max_token_length = max_token_length - Gpt.encoding_name = encoding_name - Gpt.url = url - Gpt.limit = limit - Gpt.prompt = prompt + @staticmethod + def init_config_attr(config): + Gpt.max_token_length = config["gpt"]["max_token_length"] + Gpt.encoding_name = config["gpt"]["encoding_name"] + Gpt.url = config["gpt"]["url"] + Gpt.limit = config["gpt"]["limit"] + Gpt.prompt = config["gpt"]["prompt"] + + Gpt.auth_url = config["auth"]["auth_url"] + Gpt.app_id = config["auth"]["app_id"] + Gpt.app_secret = config["auth"]["app_secret"] class Bot(Gpt): @@ -107,16 +115,30 @@ class Bot(Gpt): LGTM ''' + + def get_token(self): + params = { + 'grant_type': 'secret1', + 'app_id': self.app_id, + 'app_secret': self.app_secret + } + try: + resp = requests.get(url=self.auth_url, params=params) + data = resp.json() + token = data.get('token') + return token + except Exception as e: + logger.error(f"Failed to get token: {e}") + return None + def chat(self, prompt): data = { - "model": "gpt-3.5-turbo", "messages": [ { "role": "user", "content": prompt } - ], - "temperature": 0.7, + ] } response = requests.post( self.url, json=data @@ -127,6 +149,37 @@ class Bot(Gpt): return response + def stream_chat(self, prompt): + token = self.get_token() + if not token: + logger.error(f"Failed to get token") + return + data = { + "messages": [ + { + "role": "user", + "content": prompt + } + ] + } + header = {'Authorization': token} + response = requests.post( + self.url, json=data, headers=header + ) + if response.status_code != 200: + logger.info("get answer error") + logger.info(response.text) + + response = requests.post( + self.url, json=data, headers=header + ) + resp = '' + for res in response.iter_lines(): + item = res.decode('utf-8') + answer = json.loads(item.split('data:')[-1]).get('answer') + resp += answer + return resp + def get_token_count(self, content): encoding = tiktoken.get_encoding(self.encoding_name) tokens = encoding.encode(content) diff --git a/pr_review/src/main.py b/pr_review/src/main.py index 47e888c4e62d3868455f90412c16642ffe059616..7247009f8eb3d8a7a26391e3bb4ef9fe91e08a3f 100644 --- a/pr_review/src/main.py +++ b/pr_review/src/main.py @@ -1,11 +1,8 @@ -import os from router import router from config import init_config - def main(): - path = os.getenv("APPLICATION_PATH") - init_config.init_config(path) + init_config.init_wokers() router.start_router() diff --git a/pr_review/src/review_code/options.py b/pr_review/src/review_code/options.py index 591c5927ab2f9702bb29c102eaa6d0b134be5212..d6aef6d6852b7daccac9bd662b574a786a18d19c 100644 --- a/pr_review/src/review_code/options.py +++ b/pr_review/src/review_code/options.py @@ -1,14 +1,23 @@ import fnmatch - +from config.init_config import config class Options: def __init__(self): self.maxFiles = 1000 self.pathFilters= '' self.TokenLimits = 1024 - self.rules = {'*.py': False, '*.java': False, '*.js': False, '*.cpp': False,} + self.rules = self.init_rules() self.debug = False + def init_rules(self): + rules_dic = {} + if not config.get('rules'): + return rules_dic + rules = config.get('rules').split(',') + for rule in rules: + rules_dic['*.' + rule] = False + return rules_dic + def checkPath(self, path): # 默认所有路径均符合rules,值为True的路径才是过滤掉的路径 if len(self.rules) == 0: