feat: treino assíncrono com polling de status

/treinar agora dispara em background thread e retorna imediato.
Novo endpoint GET /treinar/status expõe estado (idle/running/concluido/vazio/erro).
This commit is contained in:
2026-06-10 14:40:42 -03:00
parent 8bdb28cdc1
commit f7ab4219ca
+95 -43
View File
@@ -5,6 +5,7 @@ import os
import io import io
import shutil import shutil
import yaml import yaml
import threading
from PIL import Image from PIL import Image
from PIL import Image as PILImage from PIL import Image as PILImage
from datetime import datetime from datetime import datetime
@@ -15,6 +16,10 @@ BUCKET = os.getenv('BUCKET_NAME', 'ia-gondola-projeto-2024')
MODELOS_CARREGADOS = {} MODELOS_CARREGADOS = {}
_treinamento_status = {"status": "idle", "ambiente": None, "versao": None, "detalhe": None}
_status_lock = threading.Lock()
def redimensionar_imagem(caminho): def redimensionar_imagem(caminho):
PILImage.MAX_IMAGE_PIXELS = None PILImage.MAX_IMAGE_PIXELS = None
img = PILImage.open(caminho) img = PILImage.open(caminho)
@@ -40,44 +45,28 @@ def carregar_modelo_do_s3(ambiente: str):
resp_head = s3.head_object(Bucket=BUCKET, Key=s3_key) resp_head = s3.head_object(Bucket=BUCKET, Key=s3_key)
data_s3 = resp_head['LastModified'].timestamp() data_s3 = resp_head['LastModified'].timestamp()
if data_s3 > data_local: if data_s3 > data_local:
log_print(f"🔄 Modelo '{ambiente}' atualizado no S3. Atualizando local...") log_print(f"Modelo '{ambiente}' atualizado no S3. Atualizando local...")
os.remove(local_path) os.remove(local_path)
if ambiente in MODELOS_CARREGADOS: if ambiente in MODELOS_CARREGADOS:
del MODELOS_CARREGADOS[ambiente] del MODELOS_CARREGADOS[ambiente]
if ambiente not in MODELOS_CARREGADOS: if ambiente not in MODELOS_CARREGADOS:
if not os.path.exists(local_path): if not os.path.exists(local_path):
log_print(f"📥 Baixando cérebro '{ambiente}'...") log_print(f"Baixando cerebro '{ambiente}'...")
s3.download_file(BUCKET, s3_key, local_path) s3.download_file(BUCKET, s3_key, local_path)
MODELOS_CARREGADOS[ambiente] = YOLO(local_path) MODELOS_CARREGADOS[ambiente] = YOLO(local_path)
return MODELOS_CARREGADOS[ambiente] return MODELOS_CARREGADOS[ambiente]
@app.post("/detectar")
async def detectar(ambiente: str = Form(...), file: UploadFile = File(...)):
try:
Image.MAX_IMAGE_PIXELS = None
modelo = carregar_modelo_do_s3(ambiente)
conteudo = await file.read()
imagem = Image.open(io.BytesIO(conteudo))
results = modelo(imagem, conf=0.25)
deteccoes = [{"box": [round(x, 2) for x in b.xyxy[0].tolist()], "conf": round(float(b.conf), 2), "class": int(b.cls)} for r in results for b in r.boxes]
return {"status": "sucesso", "deteccoes": deteccoes}
except Exception as e:
log_print(f"❌ Erro detecção: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/treinar") def _executar_treino_sync(ambiente: str, pular_triagem: bool) -> dict:
async def treinar(dados: dict): """Execução síncrona do treino — chamada em background thread."""
ambiente = dados.get("ambiente", "gondola")
pular_triagem = dados.get("pular_triagem", False)
try:
prefix_novos = f"treinamento/{ambiente}/novos-treinamentos/" prefix_novos = f"treinamento/{ambiente}/novos-treinamentos/"
objs = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix_novos) objs = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix_novos)
modelo_base = carregar_modelo_do_s3(ambiente) modelo_base = carregar_modelo_do_s3(ambiente)
if pular_triagem: if pular_triagem:
log_print(f"🚀 TRIAGEM PULADA: movendo todos os arquivos direto para base-oficial ({ambiente})") log_print(f"TRIAGEM PULADA: movendo todos os arquivos direto para base-oficial ({ambiente})")
if 'Contents' in objs: if 'Contents' in objs:
for obj in objs['Contents']: for obj in objs['Contents']:
k = obj['Key'] k = obj['Key']
@@ -85,41 +74,39 @@ async def treinar(dados: dict):
new_k = k.replace("novos-treinamentos", "base-oficial") new_k = k.replace("novos-treinamentos", "base-oficial")
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k) s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
s3.delete_object(Bucket=BUCKET, Key=k) s3.delete_object(Bucket=BUCKET, Key=k)
log_print(f"{os.path.basename(k)} -> base-oficial") log_print(f"{os.path.basename(k)} -> base-oficial")
else: else:
log_print(f"🚀 INICIANDO TRIAGEM INTELIGENTE: {ambiente}") log_print(f"INICIANDO TRIAGEM INTELIGENTE: {ambiente}")
if 'Contents' in objs: if 'Contents' in objs:
for obj in objs['Contents']: for obj in objs['Contents']:
if obj['Key'].endswith(('.jpg', '.jpeg', '.png')): if obj['Key'].endswith(('.jpg', '.jpeg', '.png')):
img_key = obj['Key'] img_key = obj['Key']
txt_key = img_key.rsplit('.', 1)[0] + ".txt" txt_key = img_key.rsplit('.', 1)[0] + ".txt"
resp_img = s3.get_object(Bucket=BUCKET, Key=img_key) resp_img = s3.get_object(Bucket=BUCKET, Key=img_key)
img_data = Image.open(io.BytesIO(resp_img['Body'].read())) img_data = Image.open(io.BytesIO(resp_img['Body'].read()))
pred = modelo_base(img_data, conf=0.1, verbose=False) pred = modelo_base(img_data, conf=0.1, verbose=False)
confs = [float(b.conf) for r in pred for b in r.boxes] confs = [float(b.conf) for r in pred for b in r.boxes]
media = sum(confs)/len(confs) if confs else 0 media = sum(confs)/len(confs) if confs else 0
# RÉGUA AJUSTADA PARA 0.30 PARA ACEITAR MAIS IMAGENS
decisao = "base-oficial" if media >= 0.30 or media == 0 else "descartados" decisao = "base-oficial" if media >= 0.30 or media == 0 else "descartados"
log_print(f"⚖️ {os.path.basename(img_key)}: Conf. {media:.2f} -> {decisao}") log_print(f"{os.path.basename(img_key)}: Conf. {media:.2f} -> {decisao}")
for k in [img_key, txt_key]: for k in [img_key, txt_key]:
try: try:
s3.head_object(Bucket=BUCKET, Key=k) s3.head_object(Bucket=BUCKET, Key=k)
new_k = k.replace("novos-treinamentos", decisao) new_k = k.replace("novos-treinamentos", decisao)
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k) s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
s3.delete_object(Bucket=BUCKET, Key=k) s3.delete_object(Bucket=BUCKET, Key=k)
except: continue except Exception:
continue
dataset_local = f"/tmp/dataset_{ambiente}" dataset_local = f"/tmp/dataset_{ambiente}"
img_dir, lbl_dir = f"{dataset_local}/train/images", f"{dataset_local}/train/labels" img_dir = f"{dataset_local}/train/images"
if os.path.exists(dataset_local): shutil.rmtree(dataset_local) lbl_dir = f"{dataset_local}/train/labels"
os.makedirs(img_dir, exist_ok=True); os.makedirs(lbl_dir, exist_ok=True) if os.path.exists(dataset_local):
shutil.rmtree(dataset_local)
os.makedirs(img_dir, exist_ok=True)
os.makedirs(lbl_dir, exist_ok=True)
log_print("📥 Baixando Base Oficial (Ouro)...") log_print("Baixando Base Oficial (Ouro)...")
ouro = s3.list_objects_v2(Bucket=BUCKET, Prefix=f"treinamento/{ambiente}/base-oficial/") ouro = s3.list_objects_v2(Bucket=BUCKET, Prefix=f"treinamento/{ambiente}/base-oficial/")
if 'Contents' in ouro: if 'Contents' in ouro:
for o in ouro['Contents']: for o in ouro['Contents']:
@@ -136,8 +123,7 @@ async def treinar(dados: dict):
with open(yaml_path, 'w') as f: with open(yaml_path, 'w') as f:
yaml.dump({'train': img_dir, 'val': img_dir, 'nc': 1, 'names': {0: ambiente}}, f) yaml.dump({'train': img_dir, 'val': img_dir, 'nc': 1, 'names': {0: ambiente}}, f)
log_print(f"🏋️ Treinando com {len(os.listdir(img_dir))} fotos. Foco no mAP50...") log_print(f"Treinando com {len(os.listdir(img_dir))} fotos...")
# O YOLO imprimirá o mAP automaticamente no log do Docker aqui:
modelo_base.train(data=yaml_path, epochs=30, imgsz=640, batch=16, device='cpu', plots=True) modelo_base.train(data=yaml_path, epochs=30, imgsz=640, batch=16, device='cpu', plots=True)
best = "runs/detect/train/weights/best.pt" best = "runs/detect/train/weights/best.pt"
@@ -145,16 +131,82 @@ async def treinar(dados: dict):
carimbo = datetime.now().strftime("%Y%m%d_%H%M") carimbo = datetime.now().strftime("%Y%m%d_%H%M")
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/atual/cerebro.pt") s3.upload_file(best, BUCKET, f"modelos/{ambiente}/atual/cerebro.pt")
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/versionamento/cerebro_{carimbo}.pt") s3.upload_file(best, BUCKET, f"modelos/{ambiente}/versionamento/cerebro_{carimbo}.pt")
if ambiente in MODELOS_CARREGADOS: del MODELOS_CARREGADOS[ambiente] if ambiente in MODELOS_CARREGADOS:
log_print("✅ TREINAMENTO CONCLUÍDO!") del MODELOS_CARREGADOS[ambiente]
log_print("TREINAMENTO CONCLUIDO!")
return {"status": "sucesso", "versao": carimbo} return {"status": "sucesso", "versao": carimbo}
else:
log_print("⚠️ Nenhuma imagem passou na triagem para a Base Oficial.") log_print("Nenhuma imagem passou na triagem para a Base Oficial.")
return {"status": "vazio"} return {"status": "vazio"}
def _treinar_bg(ambiente: str, pular_triagem: bool):
global _treinamento_status
try:
resultado = _executar_treino_sync(ambiente, pular_triagem)
with _status_lock:
if resultado["status"] == "sucesso":
_treinamento_status = {
"status": "concluido",
"ambiente": ambiente,
"versao": resultado.get("versao"),
"detalhe": None,
}
else:
_treinamento_status = {
"status": "vazio",
"ambiente": ambiente,
"versao": None,
"detalhe": "Nenhuma imagem passou na triagem",
}
except Exception as e: except Exception as e:
log_print(f"💥 Erro: {str(e)}") log_print(f"Erro no treino: {str(e)}")
return {"status": "erro", "detalhe": str(e)} with _status_lock:
_treinamento_status = {
"status": "erro",
"ambiente": ambiente,
"versao": None,
"detalhe": str(e)[:500],
}
@app.post("/detectar")
async def detectar(ambiente: str = Form(...), file: UploadFile = File(...)):
try:
Image.MAX_IMAGE_PIXELS = None
modelo = carregar_modelo_do_s3(ambiente)
conteudo = await file.read()
imagem = Image.open(io.BytesIO(conteudo))
results = modelo(imagem, conf=0.25)
deteccoes = [
{"box": [round(x, 2) for x in b.xyxy[0].tolist()], "conf": round(float(b.conf), 2), "class": int(b.cls)}
for r in results for b in r.boxes
]
return {"status": "sucesso", "deteccoes": deteccoes}
except Exception as e:
log_print(f"Erro deteccao: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/treinar")
async def treinar(dados: dict):
global _treinamento_status
with _status_lock:
if _treinamento_status["status"] == "running":
raise HTTPException(status_code=409, detail="Treinamento já em andamento")
ambiente = dados.get("ambiente", "gondola")
pular_triagem = dados.get("pular_triagem", False)
_treinamento_status = {"status": "running", "ambiente": ambiente, "versao": None, "detalhe": None}
threading.Thread(target=_treinar_bg, args=(ambiente, pular_triagem), daemon=True).start()
return {"status": "iniciado", "ambiente": ambiente}
@app.get("/treinar/status")
async def status_treino():
with _status_lock:
return dict(_treinamento_status)
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn