128 lines
6.0 KiB
Python
128 lines
6.0 KiB
Python
from django.core.management.base import BaseCommand
|
||
from openai import OpenAI
|
||
from huodong.models import Branch
|
||
import time
|
||
import re
|
||
|
||
|
||
class Command(BaseCommand):
|
||
help = 'Update branch provinces using AI API'
|
||
|
||
def handle(self, *args, **options):
|
||
client = OpenAI(
|
||
base_url="https://integrate.api.nvidia.com/v1",
|
||
api_key="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"
|
||
)
|
||
|
||
def get_correct_province(branch_name, current_location, retry_count=0):
|
||
prompt = f"""请根据以下分支机构名称,返回一个正确的中国省份名称(全称,如:浙江省、北京市、上海市、广东省等)。
|
||
|
||
分支机构名称:{branch_name}
|
||
|
||
要求:
|
||
1. 只返回省份名称,不要包含任何其他文字
|
||
2. 省份名称必须是中国的省级行政区全称
|
||
3. 如果无法确定,返回None
|
||
4. 不要被名称里含有的地名所影响
|
||
例子一:
|
||
分支机构名称:石家庄中山西路营业部
|
||
返回:河北省
|
||
例子二:
|
||
分支机构名称:北京市海淀区营业部
|
||
返回:北京市
|
||
例子三:
|
||
分支机构名称:奎屯北京东路营业部
|
||
返回:新疆维吾尔自治区
|
||
|
||
"""
|
||
|
||
|
||
try:
|
||
completion = client.chat.completions.create(
|
||
model="deepseek-ai/deepseek-r1",
|
||
messages=[{"role": "user", "content": prompt}],
|
||
temperature=0.3,
|
||
top_p=0.7,
|
||
max_tokens=100,
|
||
stream=False
|
||
)
|
||
|
||
message = completion.choices[0].message
|
||
reasoning = getattr(message, "reasoning_content", None)
|
||
content = message.content
|
||
|
||
if reasoning:
|
||
self.stdout.write(f" 推理过程: {reasoning[:200]}...")
|
||
|
||
if content is None:
|
||
if reasoning:
|
||
self.stdout.write(f" content为None,尝试从推理过程提取...")
|
||
|
||
patterns = [
|
||
r'(?:正确的(?:答案)?应该是|所以|因此|答案是|结果为)[::\s]*(北京市|天津市|上海市|重庆市|河北省|山西省|辽宁省|吉林省|黑龙江省|江苏省|浙江省|安徽省|福建省|江西省|山东省|河南省|湖北省|湖南省|广东省|海南省|四川省|贵州省|云南省|陕西省|甘肃省|青海省|台湾省|内蒙古自治区|广西壮族自治区|西藏自治区|宁夏回族自治区|新疆维吾尔自治区|香港特别行政区|澳门特别行政区)',
|
||
r'(北京市|天津市|上海市|重庆市|河北省|山西省|辽宁省|吉林省|黑龙江省|江苏省|浙江省|安徽省|福建省|江西省|山东省|河南省|湖北省|湖南省|广东省|海南省|四川省|贵州省|云南省|陕西省|甘肃省|青海省|台湾省|内蒙古自治区|广西壮族自治区|西藏自治区|宁夏回族自治区|新疆维吾尔自治区|香港特别行政区|澳门特别行政区)'
|
||
]
|
||
|
||
for pattern in patterns:
|
||
match = re.search(pattern, reasoning)
|
||
if match:
|
||
result = match.group(1)
|
||
self.stdout.write(f" 从推理提取结果: {result}")
|
||
return result
|
||
|
||
if retry_count < 2:
|
||
self.stdout.write(f" content为None,重试中 ({retry_count + 1}/2)...")
|
||
time.sleep(2)
|
||
return get_correct_province(branch_name, current_location, retry_count + 1)
|
||
else:
|
||
self.stdout.write(self.style.WARNING(f" 多次重试后仍无结果,使用当前省份"))
|
||
return current_location
|
||
|
||
result = content.strip()
|
||
self.stdout.write(f" 原始结果: {result}")
|
||
|
||
return result
|
||
|
||
except Exception as e:
|
||
if retry_count < 2:
|
||
self.stdout.write(self.style.WARNING(f" API调用失败: {e},重试中 ({retry_count + 1}/2)..."))
|
||
time.sleep(2)
|
||
return get_correct_province(branch_name, current_location, retry_count + 1)
|
||
else:
|
||
self.stdout.write(self.style.ERROR(f" 多次重试后仍失败: {e},使用当前省份"))
|
||
return current_location
|
||
|
||
branches = Branch.objects.all()
|
||
total = branches.count()
|
||
updated_count = 0
|
||
|
||
self.stdout.write(self.style.SUCCESS(f"开始处理 {total} 个分支机构..."))
|
||
self.stdout.write("=" * 80)
|
||
|
||
for index, branch in enumerate(branches, 1):
|
||
self.stdout.write(f"\n[{index}/{total}] 处理: {branch.name}")
|
||
self.stdout.write(f"当前省份: {branch.location}")
|
||
|
||
try:
|
||
correct_province = get_correct_province(branch.name, branch.location)
|
||
self.stdout.write(f"建议省份: {correct_province}")
|
||
|
||
if correct_province != branch.location:
|
||
old_location = branch.location
|
||
branch.location = correct_province
|
||
branch.save()
|
||
self.stdout.write(self.style.SUCCESS(f"✓ 已更新: {old_location} -> {correct_province}"))
|
||
updated_count += 1
|
||
else:
|
||
self.stdout.write("- 省份未变化")
|
||
|
||
except Exception as e:
|
||
self.stdout.write(self.style.ERROR(f"✗ 处理失败: {e}"))
|
||
|
||
self.stdout.write("-" * 80)
|
||
|
||
self.stdout.write(self.style.SUCCESS(f"\n处理完成!"))
|
||
self.stdout.write(f"总计: {total} 个分支机构")
|
||
self.stdout.write(f"已更新: {updated_count} 个")
|
||
self.stdout.write(f"未变化: {total - updated_count} 个")
|