Files
ia-gondola-engine/main.py
T
davi.dias f7ab4219ca feat: treino assíncrono com polling de status
/treinar agora dispara em background thread e retorna imediato.
Novo endpoint GET /treinar/status expõe estado (idle/running/concluido/vazio/erro).
2026-06-10 14:40:42 -03:00

214 lines
8.5 KiB
Python

from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from ultralytics import YOLO
import boto3
import os
import io
import shutil
import yaml
import threading
from PIL import Image
from PIL import Image as PILImage
from datetime import datetime
app = FastAPI()
s3 = boto3.client('s3')
BUCKET = os.getenv('BUCKET_NAME', 'ia-gondola-projeto-2024')
MODELOS_CARREGADOS = {}
_treinamento_status = {"status": "idle", "ambiente": None, "versao": None, "detalhe": None}
_status_lock = threading.Lock()
def redimensionar_imagem(caminho):
PILImage.MAX_IMAGE_PIXELS = None
img = PILImage.open(caminho)
w, h = img.size
max_px = 4096
if w > max_px or h > max_px:
ratio = min(max_px/w, max_px/h)
novo_w = int(w * ratio)
novo_h = int(h * ratio)
img = img.resize((novo_w, novo_h), PILImage.LANCZOS)
img.save(caminho, quality=95)
def log_print(msg):
carimbo = datetime.now().strftime("%H:%M:%S")
print(f"[{carimbo}] {msg}", flush=True)
def carregar_modelo_do_s3(ambiente: str):
s3_key = f"modelos/{ambiente}/atual/cerebro.pt"
local_path = f"/tmp/cerebro_{ambiente}.pt"
if os.path.exists(local_path):
data_local = os.path.getmtime(local_path)
resp_head = s3.head_object(Bucket=BUCKET, Key=s3_key)
data_s3 = resp_head['LastModified'].timestamp()
if data_s3 > data_local:
log_print(f"Modelo '{ambiente}' atualizado no S3. Atualizando local...")
os.remove(local_path)
if ambiente in MODELOS_CARREGADOS:
del MODELOS_CARREGADOS[ambiente]
if ambiente not in MODELOS_CARREGADOS:
if not os.path.exists(local_path):
log_print(f"Baixando cerebro '{ambiente}'...")
s3.download_file(BUCKET, s3_key, local_path)
MODELOS_CARREGADOS[ambiente] = YOLO(local_path)
return MODELOS_CARREGADOS[ambiente]
def _executar_treino_sync(ambiente: str, pular_triagem: bool) -> dict:
"""Execução síncrona do treino — chamada em background thread."""
prefix_novos = f"treinamento/{ambiente}/novos-treinamentos/"
objs = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix_novos)
modelo_base = carregar_modelo_do_s3(ambiente)
if pular_triagem:
log_print(f"TRIAGEM PULADA: movendo todos os arquivos direto para base-oficial ({ambiente})")
if 'Contents' in objs:
for obj in objs['Contents']:
k = obj['Key']
if k.endswith(('.jpg', '.jpeg', '.png', '.txt')):
new_k = k.replace("novos-treinamentos", "base-oficial")
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
s3.delete_object(Bucket=BUCKET, Key=k)
log_print(f"{os.path.basename(k)} -> base-oficial")
else:
log_print(f"INICIANDO TRIAGEM INTELIGENTE: {ambiente}")
if 'Contents' in objs:
for obj in objs['Contents']:
if obj['Key'].endswith(('.jpg', '.jpeg', '.png')):
img_key = obj['Key']
txt_key = img_key.rsplit('.', 1)[0] + ".txt"
resp_img = s3.get_object(Bucket=BUCKET, Key=img_key)
img_data = Image.open(io.BytesIO(resp_img['Body'].read()))
pred = modelo_base(img_data, conf=0.1, verbose=False)
confs = [float(b.conf) for r in pred for b in r.boxes]
media = sum(confs)/len(confs) if confs else 0
decisao = "base-oficial" if media >= 0.30 or media == 0 else "descartados"
log_print(f"{os.path.basename(img_key)}: Conf. {media:.2f} -> {decisao}")
for k in [img_key, txt_key]:
try:
s3.head_object(Bucket=BUCKET, Key=k)
new_k = k.replace("novos-treinamentos", decisao)
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
s3.delete_object(Bucket=BUCKET, Key=k)
except Exception:
continue
dataset_local = f"/tmp/dataset_{ambiente}"
img_dir = f"{dataset_local}/train/images"
lbl_dir = f"{dataset_local}/train/labels"
if os.path.exists(dataset_local):
shutil.rmtree(dataset_local)
os.makedirs(img_dir, exist_ok=True)
os.makedirs(lbl_dir, exist_ok=True)
log_print("Baixando Base Oficial (Ouro)...")
ouro = s3.list_objects_v2(Bucket=BUCKET, Prefix=f"treinamento/{ambiente}/base-oficial/")
if 'Contents' in ouro:
for o in ouro['Contents']:
k = o['Key']
if k.endswith(('.jpg', '.jpeg', '.txt')):
dest = img_dir if not k.endswith('.txt') else lbl_dir
caminho_local = os.path.join(dest, os.path.basename(k))
s3.download_file(BUCKET, k, caminho_local)
if not k.endswith('.txt'):
redimensionar_imagem(caminho_local)
if len(os.listdir(img_dir)) > 0:
yaml_path = f"{dataset_local}/data.yaml"
with open(yaml_path, 'w') as f:
yaml.dump({'train': img_dir, 'val': img_dir, 'nc': 1, 'names': {0: ambiente}}, f)
log_print(f"Treinando com {len(os.listdir(img_dir))} fotos...")
modelo_base.train(data=yaml_path, epochs=30, imgsz=640, batch=16, device='cpu', plots=True)
best = "runs/detect/train/weights/best.pt"
if os.path.exists(best):
carimbo = datetime.now().strftime("%Y%m%d_%H%M")
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/atual/cerebro.pt")
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/versionamento/cerebro_{carimbo}.pt")
if ambiente in MODELOS_CARREGADOS:
del MODELOS_CARREGADOS[ambiente]
log_print("TREINAMENTO CONCLUIDO!")
return {"status": "sucesso", "versao": carimbo}
log_print("Nenhuma imagem passou na triagem para a Base Oficial.")
return {"status": "vazio"}
def _treinar_bg(ambiente: str, pular_triagem: bool):
global _treinamento_status
try:
resultado = _executar_treino_sync(ambiente, pular_triagem)
with _status_lock:
if resultado["status"] == "sucesso":
_treinamento_status = {
"status": "concluido",
"ambiente": ambiente,
"versao": resultado.get("versao"),
"detalhe": None,
}
else:
_treinamento_status = {
"status": "vazio",
"ambiente": ambiente,
"versao": None,
"detalhe": "Nenhuma imagem passou na triagem",
}
except Exception as e:
log_print(f"Erro no treino: {str(e)}")
with _status_lock:
_treinamento_status = {
"status": "erro",
"ambiente": ambiente,
"versao": None,
"detalhe": str(e)[:500],
}
@app.post("/detectar")
async def detectar(ambiente: str = Form(...), file: UploadFile = File(...)):
try:
Image.MAX_IMAGE_PIXELS = None
modelo = carregar_modelo_do_s3(ambiente)
conteudo = await file.read()
imagem = Image.open(io.BytesIO(conteudo))
results = modelo(imagem, conf=0.25)
deteccoes = [
{"box": [round(x, 2) for x in b.xyxy[0].tolist()], "conf": round(float(b.conf), 2), "class": int(b.cls)}
for r in results for b in r.boxes
]
return {"status": "sucesso", "deteccoes": deteccoes}
except Exception as e:
log_print(f"Erro deteccao: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/treinar")
async def treinar(dados: dict):
global _treinamento_status
with _status_lock:
if _treinamento_status["status"] == "running":
raise HTTPException(status_code=409, detail="Treinamento já em andamento")
ambiente = dados.get("ambiente", "gondola")
pular_triagem = dados.get("pular_triagem", False)
_treinamento_status = {"status": "running", "ambiente": ambiente, "versao": None, "detalhe": None}
threading.Thread(target=_treinar_bg, args=(ambiente, pular_triagem), daemon=True).start()
return {"status": "iniciado", "ambiente": ambiente}
@app.get("/treinar/status")
async def status_treino():
with _status_lock:
return dict(_treinamento_status)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)