feat: treino assíncrono com polling de status

/treinar agora dispara em background thread e retorna imediato.
Novo endpoint GET /treinar/status expõe estado (idle/running/concluido/vazio/erro).
This commit is contained in:
2026-06-10 14:40:42 -03:00
parent 8bdb28cdc1
commit f7ab4219ca
+140 -88
View File
@@ -5,6 +5,7 @@ import os
import io
import shutil
import yaml
import threading
from PIL import Image
from PIL import Image as PILImage
from datetime import datetime
@@ -15,6 +16,10 @@ BUCKET = os.getenv('BUCKET_NAME', 'ia-gondola-projeto-2024')
MODELOS_CARREGADOS = {}
_treinamento_status = {"status": "idle", "ambiente": None, "versao": None, "detalhe": None}
_status_lock = threading.Lock()
def redimensionar_imagem(caminho):
PILImage.MAX_IMAGE_PIXELS = None
img = PILImage.open(caminho)
@@ -40,18 +45,131 @@ def carregar_modelo_do_s3(ambiente: str):
resp_head = s3.head_object(Bucket=BUCKET, Key=s3_key)
data_s3 = resp_head['LastModified'].timestamp()
if data_s3 > data_local:
log_print(f"🔄 Modelo '{ambiente}' atualizado no S3. Atualizando local...")
log_print(f"Modelo '{ambiente}' atualizado no S3. Atualizando local...")
os.remove(local_path)
if ambiente in MODELOS_CARREGADOS:
del MODELOS_CARREGADOS[ambiente]
if ambiente not in MODELOS_CARREGADOS:
if not os.path.exists(local_path):
log_print(f"📥 Baixando cérebro '{ambiente}'...")
log_print(f"Baixando cerebro '{ambiente}'...")
s3.download_file(BUCKET, s3_key, local_path)
MODELOS_CARREGADOS[ambiente] = YOLO(local_path)
return MODELOS_CARREGADOS[ambiente]
def _executar_treino_sync(ambiente: str, pular_triagem: bool) -> dict:
"""Execução síncrona do treino — chamada em background thread."""
prefix_novos = f"treinamento/{ambiente}/novos-treinamentos/"
objs = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix_novos)
modelo_base = carregar_modelo_do_s3(ambiente)
if pular_triagem:
log_print(f"TRIAGEM PULADA: movendo todos os arquivos direto para base-oficial ({ambiente})")
if 'Contents' in objs:
for obj in objs['Contents']:
k = obj['Key']
if k.endswith(('.jpg', '.jpeg', '.png', '.txt')):
new_k = k.replace("novos-treinamentos", "base-oficial")
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
s3.delete_object(Bucket=BUCKET, Key=k)
log_print(f"{os.path.basename(k)} -> base-oficial")
else:
log_print(f"INICIANDO TRIAGEM INTELIGENTE: {ambiente}")
if 'Contents' in objs:
for obj in objs['Contents']:
if obj['Key'].endswith(('.jpg', '.jpeg', '.png')):
img_key = obj['Key']
txt_key = img_key.rsplit('.', 1)[0] + ".txt"
resp_img = s3.get_object(Bucket=BUCKET, Key=img_key)
img_data = Image.open(io.BytesIO(resp_img['Body'].read()))
pred = modelo_base(img_data, conf=0.1, verbose=False)
confs = [float(b.conf) for r in pred for b in r.boxes]
media = sum(confs)/len(confs) if confs else 0
decisao = "base-oficial" if media >= 0.30 or media == 0 else "descartados"
log_print(f"{os.path.basename(img_key)}: Conf. {media:.2f} -> {decisao}")
for k in [img_key, txt_key]:
try:
s3.head_object(Bucket=BUCKET, Key=k)
new_k = k.replace("novos-treinamentos", decisao)
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
s3.delete_object(Bucket=BUCKET, Key=k)
except Exception:
continue
dataset_local = f"/tmp/dataset_{ambiente}"
img_dir = f"{dataset_local}/train/images"
lbl_dir = f"{dataset_local}/train/labels"
if os.path.exists(dataset_local):
shutil.rmtree(dataset_local)
os.makedirs(img_dir, exist_ok=True)
os.makedirs(lbl_dir, exist_ok=True)
log_print("Baixando Base Oficial (Ouro)...")
ouro = s3.list_objects_v2(Bucket=BUCKET, Prefix=f"treinamento/{ambiente}/base-oficial/")
if 'Contents' in ouro:
for o in ouro['Contents']:
k = o['Key']
if k.endswith(('.jpg', '.jpeg', '.txt')):
dest = img_dir if not k.endswith('.txt') else lbl_dir
caminho_local = os.path.join(dest, os.path.basename(k))
s3.download_file(BUCKET, k, caminho_local)
if not k.endswith('.txt'):
redimensionar_imagem(caminho_local)
if len(os.listdir(img_dir)) > 0:
yaml_path = f"{dataset_local}/data.yaml"
with open(yaml_path, 'w') as f:
yaml.dump({'train': img_dir, 'val': img_dir, 'nc': 1, 'names': {0: ambiente}}, f)
log_print(f"Treinando com {len(os.listdir(img_dir))} fotos...")
modelo_base.train(data=yaml_path, epochs=30, imgsz=640, batch=16, device='cpu', plots=True)
best = "runs/detect/train/weights/best.pt"
if os.path.exists(best):
carimbo = datetime.now().strftime("%Y%m%d_%H%M")
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/atual/cerebro.pt")
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/versionamento/cerebro_{carimbo}.pt")
if ambiente in MODELOS_CARREGADOS:
del MODELOS_CARREGADOS[ambiente]
log_print("TREINAMENTO CONCLUIDO!")
return {"status": "sucesso", "versao": carimbo}
log_print("Nenhuma imagem passou na triagem para a Base Oficial.")
return {"status": "vazio"}
def _treinar_bg(ambiente: str, pular_triagem: bool):
global _treinamento_status
try:
resultado = _executar_treino_sync(ambiente, pular_triagem)
with _status_lock:
if resultado["status"] == "sucesso":
_treinamento_status = {
"status": "concluido",
"ambiente": ambiente,
"versao": resultado.get("versao"),
"detalhe": None,
}
else:
_treinamento_status = {
"status": "vazio",
"ambiente": ambiente,
"versao": None,
"detalhe": "Nenhuma imagem passou na triagem",
}
except Exception as e:
log_print(f"Erro no treino: {str(e)}")
with _status_lock:
_treinamento_status = {
"status": "erro",
"ambiente": ambiente,
"versao": None,
"detalhe": str(e)[:500],
}
@app.post("/detectar")
async def detectar(ambiente: str = Form(...), file: UploadFile = File(...)):
try:
@@ -60,102 +178,36 @@ async def detectar(ambiente: str = Form(...), file: UploadFile = File(...)):
conteudo = await file.read()
imagem = Image.open(io.BytesIO(conteudo))
results = modelo(imagem, conf=0.25)
deteccoes = [{"box": [round(x, 2) for x in b.xyxy[0].tolist()], "conf": round(float(b.conf), 2), "class": int(b.cls)} for r in results for b in r.boxes]
deteccoes = [
{"box": [round(x, 2) for x in b.xyxy[0].tolist()], "conf": round(float(b.conf), 2), "class": int(b.cls)}
for r in results for b in r.boxes
]
return {"status": "sucesso", "deteccoes": deteccoes}
except Exception as e:
log_print(f"Erro detecção: {str(e)}")
log_print(f"Erro deteccao: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/treinar")
async def treinar(dados: dict):
ambiente = dados.get("ambiente", "gondola")
pular_triagem = dados.get("pular_triagem", False)
try:
prefix_novos = f"treinamento/{ambiente}/novos-treinamentos/"
objs = s3.list_objects_v2(Bucket=BUCKET, Prefix=prefix_novos)
global _treinamento_status
with _status_lock:
if _treinamento_status["status"] == "running":
raise HTTPException(status_code=409, detail="Treinamento já em andamento")
ambiente = dados.get("ambiente", "gondola")
pular_triagem = dados.get("pular_triagem", False)
_treinamento_status = {"status": "running", "ambiente": ambiente, "versao": None, "detalhe": None}
modelo_base = carregar_modelo_do_s3(ambiente)
threading.Thread(target=_treinar_bg, args=(ambiente, pular_triagem), daemon=True).start()
return {"status": "iniciado", "ambiente": ambiente}
if pular_triagem:
log_print(f"🚀 TRIAGEM PULADA: movendo todos os arquivos direto para base-oficial ({ambiente})")
if 'Contents' in objs:
for obj in objs['Contents']:
k = obj['Key']
if k.endswith(('.jpg', '.jpeg', '.png', '.txt')):
new_k = k.replace("novos-treinamentos", "base-oficial")
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
s3.delete_object(Bucket=BUCKET, Key=k)
log_print(f"{os.path.basename(k)} -> base-oficial")
else:
log_print(f"🚀 INICIANDO TRIAGEM INTELIGENTE: {ambiente}")
if 'Contents' in objs:
for obj in objs['Contents']:
if obj['Key'].endswith(('.jpg', '.jpeg', '.png')):
img_key = obj['Key']
txt_key = img_key.rsplit('.', 1)[0] + ".txt"
@app.get("/treinar/status")
async def status_treino():
with _status_lock:
return dict(_treinamento_status)
resp_img = s3.get_object(Bucket=BUCKET, Key=img_key)
img_data = Image.open(io.BytesIO(resp_img['Body'].read()))
pred = modelo_base(img_data, conf=0.1, verbose=False)
confs = [float(b.conf) for r in pred for b in r.boxes]
media = sum(confs)/len(confs) if confs else 0
# RÉGUA AJUSTADA PARA 0.30 PARA ACEITAR MAIS IMAGENS
decisao = "base-oficial" if media >= 0.30 or media == 0 else "descartados"
log_print(f"⚖️ {os.path.basename(img_key)}: Conf. {media:.2f} -> {decisao}")
for k in [img_key, txt_key]:
try:
s3.head_object(Bucket=BUCKET, Key=k)
new_k = k.replace("novos-treinamentos", decisao)
s3.copy_object(Bucket=BUCKET, CopySource={'Bucket': BUCKET, 'Key': k}, Key=new_k)
s3.delete_object(Bucket=BUCKET, Key=k)
except: continue
dataset_local = f"/tmp/dataset_{ambiente}"
img_dir, lbl_dir = f"{dataset_local}/train/images", f"{dataset_local}/train/labels"
if os.path.exists(dataset_local): shutil.rmtree(dataset_local)
os.makedirs(img_dir, exist_ok=True); os.makedirs(lbl_dir, exist_ok=True)
log_print("📥 Baixando Base Oficial (Ouro)...")
ouro = s3.list_objects_v2(Bucket=BUCKET, Prefix=f"treinamento/{ambiente}/base-oficial/")
if 'Contents' in ouro:
for o in ouro['Contents']:
k = o['Key']
if k.endswith(('.jpg', '.jpeg', '.txt')):
dest = img_dir if not k.endswith('.txt') else lbl_dir
caminho_local = os.path.join(dest, os.path.basename(k))
s3.download_file(BUCKET, k, caminho_local)
if not k.endswith('.txt'):
redimensionar_imagem(caminho_local)
if len(os.listdir(img_dir)) > 0:
yaml_path = f"{dataset_local}/data.yaml"
with open(yaml_path, 'w') as f:
yaml.dump({'train': img_dir, 'val': img_dir, 'nc': 1, 'names': {0: ambiente}}, f)
log_print(f"🏋️ Treinando com {len(os.listdir(img_dir))} fotos. Foco no mAP50...")
# O YOLO imprimirá o mAP automaticamente no log do Docker aqui:
modelo_base.train(data=yaml_path, epochs=30, imgsz=640, batch=16, device='cpu', plots=True)
best = "runs/detect/train/weights/best.pt"
if os.path.exists(best):
carimbo = datetime.now().strftime("%Y%m%d_%H%M")
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/atual/cerebro.pt")
s3.upload_file(best, BUCKET, f"modelos/{ambiente}/versionamento/cerebro_{carimbo}.pt")
if ambiente in MODELOS_CARREGADOS: del MODELOS_CARREGADOS[ambiente]
log_print("✅ TREINAMENTO CONCLUÍDO!")
return {"status": "sucesso", "versao": carimbo}
else:
log_print("⚠️ Nenhuma imagem passou na triagem para a Base Oficial.")
return {"status": "vazio"}
except Exception as e:
log_print(f"💥 Erro: {str(e)}")
return {"status": "erro", "detalhe": str(e)}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=8000)