f7ab4219ca
/treinar agora dispara em background thread e retorna imediato. Novo endpoint GET /treinar/status expõe estado (idle/running/concluido/vazio/erro).
214 lines
8.5 KiB
Python
214 lines
8.5 KiB
Python
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
|
from ultralytics import YOLO
|
|
import boto3
|
|
import os
|
|
import io
|
|
import shutil
|
|
import yaml
|
|
import threading
|
|
from PIL import Image
|
|
from PIL import Image as PILImage
|
|
from datetime import datetime
|
|
|
|
app = FastAPI()
|
|
s3 = boto3.client('s3')
|
|
BUCKET = os.getenv('BUCKET_NAME', 'ia-gondola-projeto-2024')
|
|
|
|
MODELOS_CARREGADOS = {}
|
|
|
|
_treinamento_status = {"status": "idle", "ambiente": None, "versao": None, "detalhe": None}
|
|
_status_lock = threading.Lock()
|
|
|
|
|
|
def redimensionar_imagem(caminho):
|
|
PILImage.MAX_IMAGE_PIXELS = None
|
|
img = PILImage.open(caminho)
|
|
w, h = img.size
|
|
max_px = 4096
|
|
if w > max_px or h > max_px:
|
|
ratio = min(max_px/w, max_px/h)
|
|
novo_w = int(w * ratio)
|
|
novo_h = int(h * ratio)
|
|
img = img.resize((novo_w, novo_h), PILImage.LANCZOS)
|
|
img.save(caminho, quality=95)
|
|
|
|
def log_print(msg):
|
|
carimbo = datetime.now().strftime("%H:%M:%S")
|
|
print(f"[{carimbo}] {msg}", flush=True)
|
|
|
|
def carregar_modelo_do_s3(ambiente: str):
|
|
s3_key = f"modelos/{ambiente}/atual/cerebro.pt"
|
|
local_path = f"/tmp/cerebro_{ambiente}.pt"
|
|
|
|
if os.path.exists(local_path):
|
|
data_local = os.path.getmtime(local_path)
|
|
resp_head = s3.head_object(Bucket=BUCKET, Key=s3_key)
|
|
data_s3 = resp_head['LastModified'].timestamp()
|
|
if data_s3 > data_local:
|
|
log_print(f"Modelo '{ambiente}' atualizado no S3. Atualizando local...")
|
|
os.remove(local_path)
|
|
if ambiente in MODELOS_CARREGADOS:
|
|
del MODELOS_CARREGADOS[ambiente]
|
|
|
|
if ambiente not in MODELOS_CARREGADOS:
|
|
if not os.path.exists(local_path):
|
|
log_print(f"Baixando cerebro '{ambiente}'...")
|
|
s3.download_file(BUCKET, s3_key, local_path)
|
|
MODELOS_CARREGADOS[ambiente] = YOLO(local_path)
|
|
return MODELOS_CARREGADOS[ambiente]
|
|
|
|
|
|
def _executar_treino_sync(ambiente: str, pular_triagem: bool) -> dict:
|
|
"""Execução síncrona do treino — chamada em background thread."""
|
|
prefix_novos = f"treinamento/{ambiente}/novos-treinamentos/"
|
|
objs = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix_novos)
|
|
|
|
modelo_base = carregar_modelo_do_s3(ambiente)
|
|
|
|
if pular_triagem:
|
|
log_print(f"TRIAGEM PULADA: movendo todos os arquivos direto para base-oficial ({ambiente})")
|
|
if 'Contents' in objs:
|
|
for obj in objs['Contents']:
|
|
k = obj['Key']
|
|
if k.endswith(('.jpg', '.jpeg', '.png', '.txt')):
|
|
new_k = k.replace("novos-treinamentos", "base-oficial")
|
|
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
|
|
s3.delete_object(Bucket=BUCKET, Key=k)
|
|
log_print(f"{os.path.basename(k)} -> base-oficial")
|
|
else:
|
|
log_print(f"INICIANDO TRIAGEM INTELIGENTE: {ambiente}")
|
|
if 'Contents' in objs:
|
|
for obj in objs['Contents']:
|
|
if obj['Key'].endswith(('.jpg', '.jpeg', '.png')):
|
|
img_key = obj['Key']
|
|
txt_key = img_key.rsplit('.', 1)[0] + ".txt"
|
|
resp_img = s3.get_object(Bucket=BUCKET, Key=img_key)
|
|
img_data = Image.open(io.BytesIO(resp_img['Body'].read()))
|
|
pred = modelo_base(img_data, conf=0.1, verbose=False)
|
|
confs = [float(b.conf) for r in pred for b in r.boxes]
|
|
media = sum(confs)/len(confs) if confs else 0
|
|
decisao = "base-oficial" if media >= 0.30 or media == 0 else "descartados"
|
|
log_print(f"{os.path.basename(img_key)}: Conf. {media:.2f} -> {decisao}")
|
|
for k in [img_key, txt_key]:
|
|
try:
|
|
s3.head_object(Bucket=BUCKET, Key=k)
|
|
new_k = k.replace("novos-treinamentos", decisao)
|
|
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
|
|
s3.delete_object(Bucket=BUCKET, Key=k)
|
|
except Exception:
|
|
continue
|
|
|
|
dataset_local = f"/tmp/dataset_{ambiente}"
|
|
img_dir = f"{dataset_local}/train/images"
|
|
lbl_dir = f"{dataset_local}/train/labels"
|
|
if os.path.exists(dataset_local):
|
|
shutil.rmtree(dataset_local)
|
|
os.makedirs(img_dir, exist_ok=True)
|
|
os.makedirs(lbl_dir, exist_ok=True)
|
|
|
|
log_print("Baixando Base Oficial (Ouro)...")
|
|
ouro = s3.list_objects_v2(Bucket=BUCKET, Prefix=f"treinamento/{ambiente}/base-oficial/")
|
|
if 'Contents' in ouro:
|
|
for o in ouro['Contents']:
|
|
k = o['Key']
|
|
if k.endswith(('.jpg', '.jpeg', '.txt')):
|
|
dest = img_dir if not k.endswith('.txt') else lbl_dir
|
|
caminho_local = os.path.join(dest, os.path.basename(k))
|
|
s3.download_file(BUCKET, k, caminho_local)
|
|
if not k.endswith('.txt'):
|
|
redimensionar_imagem(caminho_local)
|
|
|
|
if len(os.listdir(img_dir)) > 0:
|
|
yaml_path = f"{dataset_local}/data.yaml"
|
|
with open(yaml_path, 'w') as f:
|
|
yaml.dump({'train': img_dir, 'val': img_dir, 'nc': 1, 'names': {0: ambiente}}, f)
|
|
|
|
log_print(f"Treinando com {len(os.listdir(img_dir))} fotos...")
|
|
modelo_base.train(data=yaml_path, epochs=30, imgsz=640, batch=16, device='cpu', plots=True)
|
|
|
|
best = "runs/detect/train/weights/best.pt"
|
|
if os.path.exists(best):
|
|
carimbo = datetime.now().strftime("%Y%m%d_%H%M")
|
|
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/atual/cerebro.pt")
|
|
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/versionamento/cerebro_{carimbo}.pt")
|
|
if ambiente in MODELOS_CARREGADOS:
|
|
del MODELOS_CARREGADOS[ambiente]
|
|
log_print("TREINAMENTO CONCLUIDO!")
|
|
return {"status": "sucesso", "versao": carimbo}
|
|
|
|
log_print("Nenhuma imagem passou na triagem para a Base Oficial.")
|
|
return {"status": "vazio"}
|
|
|
|
|
|
def _treinar_bg(ambiente: str, pular_triagem: bool):
|
|
global _treinamento_status
|
|
try:
|
|
resultado = _executar_treino_sync(ambiente, pular_triagem)
|
|
with _status_lock:
|
|
if resultado["status"] == "sucesso":
|
|
_treinamento_status = {
|
|
"status": "concluido",
|
|
"ambiente": ambiente,
|
|
"versao": resultado.get("versao"),
|
|
"detalhe": None,
|
|
}
|
|
else:
|
|
_treinamento_status = {
|
|
"status": "vazio",
|
|
"ambiente": ambiente,
|
|
"versao": None,
|
|
"detalhe": "Nenhuma imagem passou na triagem",
|
|
}
|
|
except Exception as e:
|
|
log_print(f"Erro no treino: {str(e)}")
|
|
with _status_lock:
|
|
_treinamento_status = {
|
|
"status": "erro",
|
|
"ambiente": ambiente,
|
|
"versao": None,
|
|
"detalhe": str(e)[:500],
|
|
}
|
|
|
|
|
|
@app.post("/detectar")
|
|
async def detectar(ambiente: str = Form(...), file: UploadFile = File(...)):
|
|
try:
|
|
Image.MAX_IMAGE_PIXELS = None
|
|
modelo = carregar_modelo_do_s3(ambiente)
|
|
conteudo = await file.read()
|
|
imagem = Image.open(io.BytesIO(conteudo))
|
|
results = modelo(imagem, conf=0.25)
|
|
deteccoes = [
|
|
{"box": [round(x, 2) for x in b.xyxy[0].tolist()], "conf": round(float(b.conf), 2), "class": int(b.cls)}
|
|
for r in results for b in r.boxes
|
|
]
|
|
return {"status": "sucesso", "deteccoes": deteccoes}
|
|
except Exception as e:
|
|
log_print(f"Erro deteccao: {str(e)}")
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
@app.post("/treinar")
|
|
async def treinar(dados: dict):
|
|
global _treinamento_status
|
|
with _status_lock:
|
|
if _treinamento_status["status"] == "running":
|
|
raise HTTPException(status_code=409, detail="Treinamento já em andamento")
|
|
ambiente = dados.get("ambiente", "gondola")
|
|
pular_triagem = dados.get("pular_triagem", False)
|
|
_treinamento_status = {"status": "running", "ambiente": ambiente, "versao": None, "detalhe": None}
|
|
|
|
threading.Thread(target=_treinar_bg, args=(ambiente, pular_triagem), daemon=True).start()
|
|
return {"status": "iniciado", "ambiente": ambiente}
|
|
|
|
|
|
@app.get("/treinar/status")
|
|
async def status_treino():
|
|
with _status_lock:
|
|
return dict(_treinamento_status)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|