package com.hb.proj.api.controller; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.Date; import java.util.HashMap; import java.util.List; import java.util.Map; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.validation.annotation.Validated; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestMapping; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.annotation.RestController; import com.hb.proj.diagnose.FeatureExtractor; import com.hb.proj.diagnose.XGBoostDataUtil; import com.hb.proj.diagnose.XGBoostEngine; import com.hb.proj.model.DiagnoseStandard; import com.hb.proj.model.DiagramFeatureSample; import com.hb.proj.model.DiagramSample; import com.hb.proj.model.Hyperparam; import com.hb.proj.model.TrainLog; import com.hb.proj.model.TrainParam; import com.hb.proj.rep.DiagramSampleService; import com.hb.proj.third.RedisMsgProcessor; import com.hb.proj.utils.JacksonUtils; import com.hb.proj.utils.RespVO; import com.hb.proj.utils.RespVOBuilder; import com.hb.xframework.dao.util.PageModel; import com.hb.xframework.util.DateUtil; import com.hb.xframework.util.MapUtils; import jakarta.validation.Valid; import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.NotNull; @RestController @RequestMapping("/api") @Validated public class APIController { @Autowired private DiagramSampleService sampleService; @Autowired private XGBoostEngine xgbEngine; @Autowired private RedisMsgProcessor msgProcessor; /** * 获得批量诊断队列的当前大小 * @return */ @RequestMapping("/buffer/size") public RespVO getBufferSize(){ return RespVOBuilder.ok(msgProcessor.getQueueSize()); } /** * 获得功图绘制数据(样本功图) * @param wellId * @param paramCode * @param action * @param refTime * @return */ @RequestMapping("/sample/getDrawData") public RespVO getDrawData(@NotBlank(message = "缺少井名") String wellName,@NotNull(message = "缺少功图时间") Date diagramTime){ DiagramFeatureSample sample=sampleService.find(wellName, diagramTime); if(sample==null) { return RespVOBuilder.error("无数据"); } return RespVOBuilder.ok(buildData(sample)); } private Map buildData(DiagramFeatureSample sample){ String smpTime=DateUtil.format(sample.getDiagramTime(),"yyyy-MM-dd HH:mm:ss"); String[] dtstr=smpTime.split(" "); Map base=new HashMap(); base.put("xTitle", "位移"); base.put("yTitle", "载荷"); base.put("xUnit", "m"); base.put("yUnit", "kN"); base.put("title", "时间: "+dtstr[1]); base.put("type", "GT"); base.put("upMax", null); base.put("downMax", null); base.put("glbMax", sample.getLoadMax()); base.put("glbMin", sample.getLoadMin()); base.put("balance", null); //base.put("stroke", 3); base.put("turnIndex", null); /* base.put("yFixedMax", wparam.getYaxisMax()); base.put("yFixedMin", wparam.getYaxisMin()); base.put("xFixedMax", wparam.getXaxisMax()); base.put("xFixedMin", wparam.getXaxisMin());*/ String[] descTxt= {"隶属单位:","井名:"+sample.getWellName(),"测试日期:"+dtstr[0],"冲次:"}; Map serial=new HashMap(); List serials=convert(sample.getDisps()); Float stroke=Collections.max(serials); base.put("stroke", stroke); serial.put("xvals", serials); serial.put("yvals", convert(sample.getLoads())); serial.put("smpTime", smpTime); serial.put("dataId", sample.getSampleId()); return MapUtils.build("base",base,"serial",serial,"descTxt",descTxt); } private List convert(String vals){ String[] ary=vals.split(","); List datas=new ArrayList<>(ary.length); for(String d : ary) { datas.add(Float.parseFloat(d)); } return datas; } /** * 加载指定诊断标准的样本功图 */ @RequestMapping("/sample/load") public RespVO> load(@NotBlank(message="缺少诊断标准编号") String diagnoseId, @RequestParam(value = "currentPage", defaultValue = "1") Integer pageNum, @RequestParam(value = "pageSize", defaultValue = "20") Integer pageSize) { return RespVOBuilder.ok(sampleService.load(diagnoseId, pageNum, pageSize)); } /** * 删除指定单个样本功图 * @param sampleId * @return */ @PostMapping("/sample/del") public RespVO del(@NotNull(message="缺少样本编号") Integer sampleId){ sampleService.delSample(sampleId); return RespVOBuilder.ok(); } /** * 增加样品功图 * @param sample * @return */ @PostMapping("/sample/add") public RespVO add(@Valid @RequestBody DiagramSample srcSample){ if(sampleService.exists(srcSample.getWellName(),srcSample.getDiagramTime())) { return RespVOBuilder.error("该样本已添加过"); } DiagramFeatureSample sample=FeatureExtractor.extract(srcSample); sampleService.add(sample); return RespVOBuilder.ok(); } /** * 训练模型 * @param trainParam * @return */ @PostMapping("/sample/train") public RespVO train(@Valid TrainParam trainParam){ List samples=sampleService.loadSamples(); if(samples==null || samples.size()<5) { return RespVOBuilder.error("样本数据太少"); } Map standardMapping=sampleService.loadDiagnoseStandard(); //对样本的类别标签进行编号 List trainSamples=new ArrayList<>(); List validSamples=new ArrayList<>(); int label=0; for(DiagramFeatureSample smp : samples) { if(!standardMapping.containsKey(smp.getDiagnoseId())) { continue; } label=((Number)standardMapping.get(smp.getDiagnoseId())).intValue(); smp.setCategoryTag(label); if(DiagramFeatureSample.TYPE_VALID.equals(smp.getSampleType())) { validSamples.add(smp); } else { trainSamples.add(smp); } } if(trainSamples.size()==0) { return RespVOBuilder.error("未找到符合要求的样本数据"); } //未明确指定验证样本集时,以训练集的前20%作为验证集,其余作为训练集 if(validSamples.size()==0) { int vcount=(int)(Math.ceil(trainSamples.size()*0.2)); validSamples.addAll(trainSamples.subList(0, vcount)); trainSamples=trainSamples.subList(vcount,trainSamples.size()); } Hyperparam hyperparam=new Hyperparam(); hyperparam.setNumClass(standardMapping.size()); hyperparam.setEta(trainParam.getEta()); hyperparam.setMaxDepth(trainParam.getMaxDepth()); xgbEngine.train(hyperparam, trainParam.getTrainRound(), XGBoostDataUtil.build(trainSamples, true), XGBoostDataUtil.build(validSamples, true)); //存储训练日志 TrainLog trainLog=new TrainLog(JacksonUtils.getJSON(standardMapping)); sampleService.saveTrainLog(trainLog); return RespVOBuilder.ok(); } /** * 预测样本 * @return */ @PostMapping("/sample/predict") public RespVO predict(){ List testSamples=sampleService.loadSamples("test"); float[][] rst=xgbEngine.predict(XGBoostDataUtil.build(testSamples, false)); return RespVOBuilder.ok(rst); } /** * 单一功图诊断 * @return */ @PostMapping("/diagram/diagnose") public RespVO diagnose(@Valid @RequestBody DiagramSample srcSample){ DiagramFeatureSample sample=FeatureExtractor.extract(srcSample); float[][] rst=xgbEngine.predict(XGBoostDataUtil.build(Arrays.asList(sample), false)); int clsLabel=(int)rst[0][0]; Map labelStdMapping=sampleService.loadLabelStdMapping(); return RespVOBuilder.ok(labelStdMapping.get(clsLabel)); } @PostMapping("/sample/predict2") public RespVO diagnose2(){ /* List trainSamples=sampleService.loadSamples("train"); List validSamples=sampleService.loadSamples("valid"); List testSamples=sampleService.loadSamples("test"); float[][] rst=XGBootst.trainAndPredict(3, 20, XGBoostDataUtil.build(trainSamples, true), XGBoostDataUtil.build(validSamples, true), XGBoostDataUtil.build(testSamples, false)); return RespVOBuilder.ok(rst); */ return RespVOBuilder.ok(); } }