快乐十分开奖号码搜索|taiyuan快乐十分

星际争霸2 AI开发(持续更新)

准备

我的环境是python3.6,sc2包0.11.1
机器学习包下载链接:pysc2
地图下载链接:maps
游戏下载链接:国际服
国服
pysc2是DeepMind开发的星际争霸Ⅱ学习环境。 它是封装星际争霸Ⅱ机器学习API,同时也提供Python增强学习环境。
以神族为例编写代码,神族建筑科技图如下:

采矿

# -*- encoding: utf-8 -*-
'''
@File    :   __init__.py.py    
@Modify Time      @Author       @Desciption
------------      -------       -----------
2019/11/3 12:32   Jonas           None
'''

import sc2
from sc2 import run_game, maps, Race, Difficulty
from sc2.player import Bot, Computer


class SentdeBot(sc2.BotAI):
    """Minimal bot: its only action each step is to redistribute workers."""

    async def on_step(self, iteration: int) -> None:
        """Per-step hook; delegates to ``distribute_workers``.

        ``iteration`` is the step counter passed in by the framework and is
        not used here.
        """
        await self.distribute_workers()


# Launch a match on AcidPlantLE: our Protoss bot against an easy Terran
# computer, played in real time.
run_game(
    maps.get("AcidPlantLE"),
    [
        Bot(Race.Protoss, SentdeBot()),
        Computer(Race.Terran, Difficulty.Easy),
    ],
    realtime=True,
)

注意
game_data.py的assert self.id != 0注释掉
pixel_map.py的assert self.bits_per_pixel % 8 == 0, "Unsupported pixel density"注释掉
否则会报错

运行结果如下,农民开始采矿

可以正常采矿

建造农民和水晶塔

import sc2
from sc2 import run_game, maps, Race, Difficulty
from sc2.player import Bot, Computer
from sc2.constants import *


class SentdeBot(sc2.BotAI):
    """Bot that mines, trains probes and adds pylons when supply runs low."""

    async def on_step(self, iteration: int):
        """Run the economy routines once per game step."""
        await self.distribute_workers()
        await self.build_workers()
        await self.build_pylons()

    async def build_workers(self):
        """Train a probe at every finished nexus whose build queue is empty."""
        # Only queue-free nexuses are used so minerals are not tied up in a
        # production queue.
        idle_nexuses = self.units(UnitTypeId.NEXUS).ready.noqueue
        for nexus in idle_nexuses:
            if not self.can_afford(UnitTypeId.PROBE):
                continue
            await self.do(nexus.train(UnitTypeId.PROBE))

    async def build_pylons(self):
        """Place a pylon near the first finished nexus when supply is short."""
        # Nothing to do unless fewer than 5 supply remain and no pylon is
        # already on its way.
        if self.supply_left >= 5 or self.already_pending(UnitTypeId.PYLON):
            return
        finished_nexuses = self.units(UnitTypeId.NEXUS).ready
        if finished_nexuses.exists and self.can_afford(UnitTypeId.PYLON):
            await self.build(UnitTypeId.PYLON, near=finished_nexuses.first)

# Start the game: Protoss bot vs. an easy Terran computer, in real time.
run_game(
    maps.get("AcidPlantLE"),
    [
        Bot(Race.Protoss, SentdeBot()),
        Computer(Race.Terran, Difficulty.Easy),
    ],
    realtime=True,
)

运行结果如下,基地造农民,农民造水晶

采集气体和开矿

代码如下

import sc2
from sc2 import run_game, maps, Race, Difficulty
from sc2.player import Bot, Computer
from sc2.constants import *


class SentdeBot(sc2.BotAI):
    """Economy bot: probes, pylons, assimilators and up to three nexuses."""

    async def on_step(self, iteration: int):
        """Run every economy routine once per game step."""
        await self.distribute_workers()
        await self.build_workers()
        await self.build_pylons()
        await self.build_assimilators()
        await self.expand()

    async def build_workers(self):
        """Train a probe at every finished nexus whose build queue is empty."""
        # Only queue-free nexuses are used so minerals are not tied up in a
        # production queue.
        idle_nexuses = self.units(UnitTypeId.NEXUS).ready.noqueue
        for nexus in idle_nexuses:
            if not self.can_afford(UnitTypeId.PROBE):
                continue
            await self.do(nexus.train(UnitTypeId.PROBE))

    async def build_pylons(self):
        """Place a pylon near the first finished nexus when supply is short."""
        # Nothing to do unless fewer than 5 supply remain and no pylon is
        # already on its way.
        if self.supply_left >= 5 or self.already_pending(UnitTypeId.PYLON):
            return
        finished_nexuses = self.units(UnitTypeId.NEXUS).ready
        if finished_nexuses.exists and self.can_afford(UnitTypeId.PYLON):
            await self.build(UnitTypeId.PYLON, near=finished_nexuses.first)

    async def build_assimilators(self):
        """Build an assimilator on each free geyser within 15.0 of a finished nexus."""
        for nexus in self.units(UnitTypeId.NEXUS).ready:
            nearby_geysers = self.state.vespene_geyser.closer_than(15.0, nexus)
            for geyser in nearby_geysers:
                # Give up on this nexus once money or builders run out.
                if not self.can_afford(UnitTypeId.ASSIMILATOR):
                    break
                builder = self.select_build_worker(geyser.position)
                if builder is None:
                    break
                # An assimilator within 1.0 of the geyser means it is taken.
                taken = self.units(UnitTypeId.ASSIMILATOR).closer_than(1.0, geyser).exists
                if taken:
                    continue
                await self.do(builder.build(UnitTypeId.ASSIMILATOR, geyser))

    async def expand(self):
        """Expand while fewer than three nexuses exist and one is affordable."""
        nexus_count = self.units(UnitTypeId.NEXUS).amount
        if nexus_count < 3 and self.can_afford(UnitTypeId.NEXUS):
            await self.expand_now()

# Start the game: Protoss bot vs. an easy Terran computer, not in real time.
run_game(
    maps.get("AcidPlantLE"),
    [
        Bot(Race.Protoss, SentdeBot()),
        Computer(Race.Terran, Difficulty.Easy),
    ],
    realtime=False,
)

run_game的realtime设置成False,可以在加速模式下运行游戏。
运行效果如下:

可以建造吸收厂和开矿

建造军队

import sc2
from sc2 import run_game, maps, Race, Difficulty
from sc2.player import Bot, Computer
from sc2.constants import *


class SentdeBot(sc2.BotAI):
    """Bot that runs an economy, builds gateway tech and trains stalkers."""

    async def on_step(self, iteration: int):
        """Run the economy and army-production routines once per game step."""
        await self.distribute_workers()
        await self.build_workers()
        await self.build_pylons()
        await self.build_assimilators()
        await self.expand()
        await self.offensive_force_buildings()
        await self.build_offensive_force()

    async def build_workers(self):
        """Train a probe at every finished nexus whose build queue is empty."""
        # Only queue-free nexuses are used so minerals are not tied up in a
        # production queue.
        idle_nexuses = self.units(UnitTypeId.NEXUS).ready.noqueue
        for nexus in idle_nexuses:
            if not self.can_afford(UnitTypeId.PROBE):
                continue
            await self.do(nexus.train(UnitTypeId.PROBE))

    async def build_pylons(self):
        """Place a pylon near the first finished nexus when supply is short."""
        # Nothing to do unless fewer than 5 supply remain and no pylon is
        # already on its way.
        if self.supply_left >= 5 or self.already_pending(UnitTypeId.PYLON):
            return
        finished_nexuses = self.units(UnitTypeId.NEXUS).ready
        if finished_nexuses.exists and self.can_afford(UnitTypeId.PYLON):
            await self.build(UnitTypeId.PYLON, near=finished_nexuses.first)

    async def build_assimilators(self):
        """Build an assimilator on each free geyser within 15.0 of a finished nexus."""
        for nexus in self.units(UnitTypeId.NEXUS).ready:
            nearby_geysers = self.state.vespene_geyser.closer_than(15.0, nexus)
            for geyser in nearby_geysers:
                # Give up on this nexus once money or builders run out.
                if not self.can_afford(UnitTypeId.ASSIMILATOR):
                    break
                builder = self.select_build_worker(geyser.position)
                if builder is None:
                    break
                # An assimilator within 1.0 of the geyser means it is taken.
                taken = self.units(UnitTypeId.ASSIMILATOR).closer_than(1.0, geyser).exists
                if taken:
                    continue
                await self.do(builder.build(UnitTypeId.ASSIMILATOR, geyser))

    async def expand(self):
        """Expand while fewer than two nexuses exist and one is affordable."""
        nexus_count = self.units(UnitTypeId.NEXUS).amount
        if nexus_count < 2 and self.can_afford(UnitTypeId.NEXUS):
            await self.expand_now()

    async def offensive_force_buildings(self):
        """Build a gateway, then a cybernetics core, near a random finished pylon."""
        finished_pylons = self.units(UnitTypeId.PYLON).ready
        if not finished_pylons.exists:
            return
        pylon = finished_pylons.random

        if self.units(UnitTypeId.GATEWAY).ready.exists:
            # Tech tree: the cybernetics core needs a finished gateway first.
            if self.units(UnitTypeId.CYBERNETICSCORE):
                return
            core_affordable = self.can_afford(UnitTypeId.CYBERNETICSCORE)
            if core_affordable and not self.already_pending(UnitTypeId.CYBERNETICSCORE):
                await self.build(UnitTypeId.CYBERNETICSCORE, near=pylon)
        elif self.can_afford(UnitTypeId.GATEWAY) and not self.already_pending(UnitTypeId.GATEWAY):
            # No finished gateway yet: build one.
            await self.build(UnitTypeId.GATEWAY, near=pylon)

    async def build_offensive_force(self):
        """Train a stalker at each queue-free finished gateway when possible."""
        idle_gateways = self.units(UnitTypeId.GATEWAY).ready.noqueue
        for gateway in idle_gateways:
            if self.can_afford(UnitTypeId.STALKER) and self.supply_left > 0:
                await self.do(gateway.train(UnitTypeId.STALKER))



# Start the game: Protoss bot vs. an easy Terran computer, not in real time.
run_game(
    maps.get("AcidPlantLE"),
    [
        Bot(Race.Protoss, SentdeBot()),
        Computer(Race.Terran, Difficulty.Easy),
    ],
    realtime=False,
)

运行结果如下:

可以看到,我们建造了折跃门和控制核心并训练了追猎者

控制部队进攻

代码如下


import sc2
from sc2 import run_game, maps, Race, Difficulty
from sc2.player import Bot, Computer
from sc2.constants import *
import random

class SentdeBot(sc2.BotAI):
    """Bot that runs an economy, masses stalkers and sends them to attack."""

    async def on_step(self, iteration: int):
        """Run the economy, production and combat routines once per game step."""
        await self.distribute_workers()
        await self.build_workers()
        await self.build_pylons()
        await self.build_assimilators()
        await self.expand()
        await self.offensive_force_buildings()
        await self.build_offensive_force()
        await self.attack()

    async def build_workers(self):
        """Train a probe at every finished nexus whose build queue is empty."""
        # Only queue-free nexuses are used so minerals are not tied up in a
        # production queue.
        for nexus in self.units(UnitTypeId.NEXUS).ready.noqueue:
            if self.can_afford(UnitTypeId.PROBE):
                await self.do(nexus.train(UnitTypeId.PROBE))

    async def build_pylons(self):
        """Place a pylon near the first finished nexus when supply is short."""
        # Nothing to do unless fewer than 5 supply remain and no pylon is
        # already on its way.
        if self.supply_left >= 5 or self.already_pending(UnitTypeId.PYLON):
            return
        finished_nexuses = self.units(UnitTypeId.NEXUS).ready
        if finished_nexuses.exists and self.can_afford(UnitTypeId.PYLON):
            await self.build(UnitTypeId.PYLON, near=finished_nexuses.first)

    async def build_assimilators(self):
        """Build an assimilator on each free geyser within 15.0 of a finished nexus."""
        for nexus in self.units(UnitTypeId.NEXUS).ready:
            nearby_geysers = self.state.vespene_geyser.closer_than(15.0, nexus)
            for geyser in nearby_geysers:
                # Give up on this nexus once money or builders run out.
                if not self.can_afford(UnitTypeId.ASSIMILATOR):
                    break
                builder = self.select_build_worker(geyser.position)
                if builder is None:
                    break
                # An assimilator within 1.0 of the geyser means it is taken.
                if not self.units(UnitTypeId.ASSIMILATOR).closer_than(1.0, geyser).exists:
                    await self.do(builder.build(UnitTypeId.ASSIMILATOR, geyser))

    async def expand(self):
        """Expand while fewer than three nexuses exist and one is affordable."""
        if self.units(UnitTypeId.NEXUS).amount < 3 and self.can_afford(UnitTypeId.NEXUS):
            await self.expand_now()

    async def offensive_force_buildings(self):
        """Build a cybernetics core once a gateway is done, else more gateways.

        Gateways are added while at most three exist; everything is placed
        near a randomly chosen finished pylon.
        """
        finished_pylons = self.units(UnitTypeId.PYLON).ready
        if not finished_pylons.exists:
            return
        pylon = finished_pylons.random

        gateway_ready = self.units(UnitTypeId.GATEWAY).ready.exists
        # Tech tree: the cybernetics core needs a finished gateway first.
        if gateway_ready and not self.units(UnitTypeId.CYBERNETICSCORE):
            if self.can_afford(UnitTypeId.CYBERNETICSCORE) and not self.already_pending(UnitTypeId.CYBERNETICSCORE):
                await self.build(UnitTypeId.CYBERNETICSCORE, near=pylon)
        elif len(self.units(UnitTypeId.GATEWAY)) <= 3:
            if self.can_afford(UnitTypeId.GATEWAY) and not self.already_pending(UnitTypeId.GATEWAY):
                await self.build(UnitTypeId.GATEWAY, near=pylon)

    async def build_offensive_force(self):
        """Train a stalker at each queue-free finished gateway when possible."""
        for gateway in self.units(UnitTypeId.GATEWAY).ready.noqueue:
            if self.can_afford(UnitTypeId.STALKER) and self.supply_left > 0:
                await self.do(gateway.train(UnitTypeId.STALKER))

    def find_target(self, state):
        """Pick what the army should attack.

        Preference order: a random known enemy unit, then a random known
        enemy structure, then the first enemy start location. ``state`` is
        accepted for caller compatibility and is not used.
        """
        if len(self.known_enemy_units) > 0:
            return random.choice(self.known_enemy_units)
        # Fixed: this branch used to re-test known_enemy_units, which made
        # the structure fallback unreachable.
        if len(self.known_enemy_structures) > 0:
            return random.choice(self.known_enemy_structures)
        return self.enemy_start_locations[0]

    async def attack(self):
        """Order idle stalkers to attack (more than 15) or defend (more than 5)."""
        stalkers = self.units(UnitTypeId.STALKER)
        if stalkers.amount > 15:
            # Attack mode: go after whatever find_target selects.
            for stalker in stalkers.idle:
                await self.do(stalker.attack(self.find_target(self.state)))
        elif stalkers.amount > 5:
            # Defence mode: only react to enemies that are currently known.
            # ``elif`` keeps the same idle stalkers from receiving a second,
            # overriding order in the step where attack mode already fired.
            if len(self.known_enemy_units) > 0:
                for stalker in stalkers.idle:
                    await self.do(stalker.attack(random.choice(self.known_enemy_units)))

# Start the game: Protoss bot vs. a medium Terran computer, not in real time.
run_game(
    maps.get("AcidPlantLE"),
    [
        Bot(Race.Protoss, SentdeBot()),
        Computer(Race.Terran, Difficulty.Medium),
    ],
    realtime=False,
)

运行结果如下


可以看到,4个折跃门训练追猎者并发动进攻。

击败困难电脑

我们目前的代码只能击败中等和简单电脑,那么如何击败困难电脑呢?
代码如下


import sc2
from sc2 import run_game, maps, Race, Difficulty
from sc2.player import Bot, Computer
from sc2.constants import *
import random


class SentdeBot(sc2.BotAI):
    """Bot that scales bases and production with game time; stalkers plus void rays."""

    def __init__(self):
        """Set the tuning constants used by the time-based build logic."""
        # Run the base-class initialiser (a no-op if BotAI defines none) so
        # any framework state it sets up is not skipped.
        super().__init__()
        # Measured empirically: roughly 165 on_step iterations per game minute.
        self.ITERATIONS_PER_MINUTE = 165
        # Upper bound on the number of probes.
        self.MAX_WORKERS = 65

    async def on_step(self, iteration: int):
        """Store the step counter, then run every routine once."""
        self.iteration = iteration
        await self.distribute_workers()
        await self.build_workers()
        await self.build_pylons()
        await self.build_assimilators()
        await self.expand()
        await self.offensive_force_buildings()
        await self.build_offensive_force()
        await self.attack()

    async def build_workers(self):
        """Train probes up to 16 per nexus, capped at ``MAX_WORKERS``."""
        probe_count = len(self.units(UnitTypeId.PROBE))
        wanted = len(self.units(UnitTypeId.NEXUS)) * 16
        if wanted > probe_count and probe_count < self.MAX_WORKERS:
            # Only queue-free nexuses are used so minerals are not tied up
            # in a production queue.
            for nexus in self.units(UnitTypeId.NEXUS).ready.noqueue:
                if self.can_afford(UnitTypeId.PROBE):
                    await self.do(nexus.train(UnitTypeId.PROBE))

    async def build_pylons(self):
        """Place a pylon near the first finished nexus when supply is short."""
        # Nothing to do unless fewer than 5 supply remain and no pylon is
        # already on its way.
        if self.supply_left >= 5 or self.already_pending(UnitTypeId.PYLON):
            return
        finished_nexuses = self.units(UnitTypeId.NEXUS).ready
        if finished_nexuses.exists and self.can_afford(UnitTypeId.PYLON):
            await self.build(UnitTypeId.PYLON, near=finished_nexuses.first)

    async def build_assimilators(self):
        """Build an assimilator on each free geyser within 15.0 of a finished nexus."""
        for nexus in self.units(UnitTypeId.NEXUS).ready:
            nearby_geysers = self.state.vespene_geyser.closer_than(15.0, nexus)
            for geyser in nearby_geysers:
                # Give up on this nexus once money or builders run out.
                if not self.can_afford(UnitTypeId.ASSIMILATOR):
                    break
                builder = self.select_build_worker(geyser.position)
                if builder is None:
                    break
                # An assimilator within 1.0 of the geyser means it is taken.
                if not self.units(UnitTypeId.ASSIMILATOR).closer_than(1.0, geyser).exists:
                    await self.do(builder.build(UnitTypeId.ASSIMILATOR, geyser))

    async def expand(self):
        """Expand while the nexus count is below the elapsed minutes."""
        # iteration / ITERATIONS_PER_MINUTE grows slowly, so the allowed
        # number of bases rises with game time.
        minutes = self.iteration / self.ITERATIONS_PER_MINUTE
        if self.units(UnitTypeId.NEXUS).amount < minutes and self.can_afford(UnitTypeId.NEXUS):
            await self.expand_now()

    async def offensive_force_buildings(self):
        """Build cybernetics core, gateways and stargates near a random finished pylon.

        Gateway and stargate counts are each limited to half the elapsed
        minutes. Also prints the elapsed-minutes value every call.
        """
        print(self.iteration / self.ITERATIONS_PER_MINUTE)
        finished_pylons = self.units(UnitTypeId.PYLON).ready
        if not finished_pylons.exists:
            return
        pylon = finished_pylons.random
        # Slowly increasing cap shared by gateways and stargates.
        building_cap = (self.iteration / self.ITERATIONS_PER_MINUTE) / 2

        # Tech tree: the cybernetics core needs a finished gateway first.
        if self.units(UnitTypeId.GATEWAY).ready.exists and not self.units(UnitTypeId.CYBERNETICSCORE):
            if self.can_afford(UnitTypeId.CYBERNETICSCORE) and not self.already_pending(UnitTypeId.CYBERNETICSCORE):
                await self.build(UnitTypeId.CYBERNETICSCORE, near=pylon)
        elif len(self.units(UnitTypeId.GATEWAY)) < building_cap:
            if self.can_afford(UnitTypeId.GATEWAY) and not self.already_pending(UnitTypeId.GATEWAY):
                await self.build(UnitTypeId.GATEWAY, near=pylon)

        # Stargates are only attempted once a cybernetics core is finished.
        if self.units(UnitTypeId.CYBERNETICSCORE).ready.exists:
            if len(self.units(UnitTypeId.STARGATE)) < building_cap:
                if self.can_afford(UnitTypeId.STARGATE) and not self.already_pending(UnitTypeId.STARGATE):
                    await self.build(UnitTypeId.STARGATE, near=pylon)

    async def build_offensive_force(self):
        """Train stalkers and void rays from queue-free production buildings.

        Stalkers are only trained while they do not outnumber void rays.
        """
        for gateway in self.units(UnitTypeId.GATEWAY).ready.noqueue:
            if self.units(UnitTypeId.STALKER).amount <= self.units(UnitTypeId.VOIDRAY).amount:
                if self.can_afford(UnitTypeId.STALKER) and self.supply_left > 0:
                    await self.do(gateway.train(UnitTypeId.STALKER))

        for stargate in self.units(UnitTypeId.STARGATE).ready.noqueue:
            if self.can_afford(UnitTypeId.VOIDRAY) and self.supply_left > 0:
                await self.do(stargate.train(UnitTypeId.VOIDRAY))

    def find_target(self, state):
        """Pick what the army should attack.

        Preference order: a random known enemy unit, then a random known
        enemy structure, then the first enemy start location. ``state`` is
        accepted for caller compatibility and is not used.
        """
        if len(self.known_enemy_units) > 0:
            return random.choice(self.known_enemy_units)
        # Fixed: this branch used to re-test known_enemy_units, which made
        # the structure fallback unreachable.
        if len(self.known_enemy_structures) > 0:
            return random.choice(self.known_enemy_structures)
        return self.enemy_start_locations[0]

    async def attack(self):
        """Send idle army units to attack or defend depending on army size."""
        # unit type -> [count needed to attack, count needed to defend]
        aggressive_units = {
            UnitTypeId.STALKER: [15, 5],
            UnitTypeId.VOIDRAY: [8, 3],
        }

        for unit_type, (fight_at, defend_at) in aggressive_units.items():
            count = self.units(unit_type).amount
            if count > fight_at and count > defend_at:
                # Attack mode: go after whatever find_target selects.
                for unit in self.units(unit_type).idle:
                    await self.do(unit.attack(self.find_target(self.state)))
            elif count > defend_at:
                # Defence mode: only react to enemies that are currently known.
                if len(self.known_enemy_units) > 0:
                    for unit in self.units(unit_type).idle:
                        await self.do(unit.attack(random.choice(self.known_enemy_units)))
# Start the game: Protoss bot vs. a hard Terran computer, not in real time.
run_game(
    maps.get("AcidPlantLE"),
    [
        Bot(Race.Protoss, SentdeBot()),
        Computer(Race.Terran, Difficulty.Hard),
    ],
    realtime=False,
)

运行结果如下

可以看到,击败了困难人族电脑,但是电脑选择了rush战术,我们写得AI脚本会输掉游戏。显然,这不是最佳方案。
“只有AI才能拯救我的胜率”,请看下文。

采集地图数据

这次我们只造一个折跃门,全力通过星门造虚空光辉舰
修改offensive_force_buildings(self)方法的判断

# Fragment of the if/elif chain in offensive_force_buildings: build a gateway
# only while none exists, if affordable and none is already pending.
elif len(self.units(GATEWAY)) < 1:
                if self.can_afford(GATEWAY) and not self.already_pending(GATEWAY):
                    await self.build(GATEWAY, near=pylon)

注释或者删除build_offensive_force(self)的建造追猎者的代码

        ## 造兵
    async def build_offensive_force(self):
        """Train void rays at queue-free finished stargates.

        The stalker-training loop below is intentionally left commented out
        for this step of the tutorial.
        """
        # Only buildings with an empty queue are used.
        # for gw in self.units(UnitTypeId.GATEWAY).ready.noqueue:
        #     if not self.units(UnitTypeId.STALKER).amount > self.units(UnitTypeId.VOIDRAY).amount:
        #
        #         if self.can_afford(UnitTypeId.STALKER) and self.supply_left > 0:
        #             await self.do(gw.train(UnitTypeId.STALKER))

        for sg in self.units(UnitTypeId.STARGATE).ready.noqueue:
            if self.can_afford(UnitTypeId.VOIDRAY) and self.supply_left > 0:
                await self.do(sg.train(UnitTypeId.VOIDRAY))

attack(self)中的aggressive_units注释掉Stalker
导入numpy和cv2库

# Zero-filled uint8 array: map_size[1] rows, map_size[0] columns, 3 channels.
game_data = np.zeros((self.game_info.map_size[1], self.game_info.map_size[0], 3), np.uint8)

建立以地图Height为行,Width为列的三维矩阵

# Mark every nexus with a filled circle (radius 10, thickness -1) at its position.
for nexus in self.units(NEXUS):
            nex_pos = nexus.position
            print(nex_pos)
            cv2.circle(game_data, (int(nex_pos[0]), int(nex_pos[1])), 10, (0, 255, 0), -1)  # BGR

遍历星灵枢纽,获取其位置,画圆,circle(承载圆的img, 圆心, 半径, 颜色, thickness=-1表示填充)
接下来我们要垂直翻转三维矩阵,因为我们建立的矩阵左上角是原点(0,0),纵坐标向下延伸,横坐标向右延伸。翻转之后就成了正常的坐标系。

# Flip with flip code 0 (vertical flip) so the image origin matches map coordinates.
flipped = cv2.flip(game_data, 0)

图像缩放,达到可视化最佳。

        # Scale the flipped image by 2 on both axes and show it in the 'Intel' window.
        resized = cv2.resize(flipped, dsize=None, fx=2, fy=2)
        cv2.imshow('Intel', resized)
        # waitKey with a 1 ms delay; presumably what lets the window refresh.
        cv2.waitKey(1)

至此,完整代码如下

import sc2
from sc2 import run_game, maps, Race, Difficulty
from sc2.player import Bot, Computer
from sc2.constants import *
import random
import numpy as np
import cv2


class SentdeBot(sc2.BotAI):
    """Void-ray bot that also renders its nexus positions with OpenCV."""

    def __init__(self):
        """Set the tuning constants used by the time-based build logic."""
        # Run the base-class initialiser (a no-op if BotAI defines none) so
        # any framework state it sets up is not skipped.
        super().__init__()
        # Measured empirically: roughly 165 on_step iterations per game minute.
        self.ITERATIONS_PER_MINUTE = 165
        # Upper bound on the number of probes.
        self.MAX_WORKERS = 65

    async def on_step(self, iteration: int):
        """Store the step counter, then run every routine once."""
        self.iteration = iteration
        await self.distribute_workers()
        await self.build_workers()
        await self.build_pylons()
        await self.build_assimilators()
        await self.expand()
        await self.offensive_force_buildings()
        await self.build_offensive_force()
        await self.intel()
        await self.attack()

    async def intel(self):
        """Draw all nexus positions on a map-sized image and display it."""
        map_width = self.game_info.map_size[0]
        map_height = self.game_info.map_size[1]
        # Rows follow map height, columns follow map width; 3 colour channels.
        game_data = np.zeros((map_height, map_width, 3), np.uint8)
        for nexus in self.units(UnitTypeId.NEXUS):
            center = nexus.position
            # circle(image, centre, radius, colour, thickness=-1 for filled)
            cv2.circle(game_data, (int(center[0]), int(center[1])), 10, (0, 255, 0), -1)
        # Vertical flip: the array origin is the top-left corner, so flipping
        # gives the usual orientation with y increasing upwards.
        flipped = cv2.flip(game_data, 0)
        # Scale by 2 on both axes for easier viewing.
        resized = cv2.resize(flipped, dsize=None, fx=2, fy=2)
        cv2.imshow('Intel', resized)
        # waitKey with a 1 ms delay.
        cv2.waitKey(1)

    async def build_workers(self):
        """Train probes up to 16 per nexus, capped at ``MAX_WORKERS``."""
        probe_count = len(self.units(UnitTypeId.PROBE))
        wanted = len(self.units(UnitTypeId.NEXUS)) * 16
        if wanted > probe_count and probe_count < self.MAX_WORKERS:
            # Only queue-free nexuses are used so minerals are not tied up
            # in a production queue.
            for nexus in self.units(UnitTypeId.NEXUS).ready.noqueue:
                if self.can_afford(UnitTypeId.PROBE):
                    await self.do(nexus.train(UnitTypeId.PROBE))

    async def build_pylons(self):
        """Place a pylon near the first finished nexus when supply is short."""
        # Nothing to do unless fewer than 5 supply remain and no pylon is
        # already on its way.
        if self.supply_left >= 5 or self.already_pending(UnitTypeId.PYLON):
            return
        finished_nexuses = self.units(UnitTypeId.NEXUS).ready
        if finished_nexuses.exists and self.can_afford(UnitTypeId.PYLON):
            await self.build(UnitTypeId.PYLON, near=finished_nexuses.first)

    async def build_assimilators(self):
        """Build an assimilator on each free geyser within 15.0 of a finished nexus."""
        for nexus in self.units(UnitTypeId.NEXUS).ready:
            nearby_geysers = self.state.vespene_geyser.closer_than(15.0, nexus)
            for geyser in nearby_geysers:
                # Give up on this nexus once money or builders run out.
                if not self.can_afford(UnitTypeId.ASSIMILATOR):
                    break
                builder = self.select_build_worker(geyser.position)
                if builder is None:
                    break
                # An assimilator within 1.0 of the geyser means it is taken.
                if not self.units(UnitTypeId.ASSIMILATOR).closer_than(1.0, geyser).exists:
                    await self.do(builder.build(UnitTypeId.ASSIMILATOR, geyser))

    async def expand(self):
        """Expand while the nexus count is below the elapsed minutes."""
        # iteration / ITERATIONS_PER_MINUTE grows slowly, so the allowed
        # number of bases rises with game time.
        minutes = self.iteration / self.ITERATIONS_PER_MINUTE
        if self.units(UnitTypeId.NEXUS).amount < minutes and self.can_afford(UnitTypeId.NEXUS):
            await self.expand_now()

    async def offensive_force_buildings(self):
        """Build cybernetics core, one gateway and stargates near a random finished pylon.

        Only a single gateway is built; the stargate count is limited to
        half the elapsed minutes. Also prints the elapsed-minutes value
        every call.
        """
        print(self.iteration / self.ITERATIONS_PER_MINUTE)
        finished_pylons = self.units(UnitTypeId.PYLON).ready
        if not finished_pylons.exists:
            return
        pylon = finished_pylons.random

        # Tech tree: the cybernetics core needs a finished gateway first.
        if self.units(UnitTypeId.GATEWAY).ready.exists and not self.units(UnitTypeId.CYBERNETICSCORE):
            if self.can_afford(UnitTypeId.CYBERNETICSCORE) and not self.already_pending(UnitTypeId.CYBERNETICSCORE):
                await self.build(UnitTypeId.CYBERNETICSCORE, near=pylon)
        elif len(self.units(UnitTypeId.GATEWAY)) < 1:
            # One gateway is enough here; it only unlocks the tech path.
            if self.can_afford(UnitTypeId.GATEWAY) and not self.already_pending(UnitTypeId.GATEWAY):
                await self.build(UnitTypeId.GATEWAY, near=pylon)

        # Stargates are only attempted once a cybernetics core is finished.
        if self.units(UnitTypeId.CYBERNETICSCORE).ready.exists:
            stargate_cap = (self.iteration / self.ITERATIONS_PER_MINUTE) / 2
            if len(self.units(UnitTypeId.STARGATE)) < stargate_cap:
                if self.can_afford(UnitTypeId.STARGATE) and not self.already_pending(UnitTypeId.STARGATE):
                    await self.build(UnitTypeId.STARGATE, near=pylon)

    async def build_offensive_force(self):
        """Train a void ray at each queue-free finished stargate when possible."""
        for stargate in self.units(UnitTypeId.STARGATE).ready.noqueue:
            if self.can_afford(UnitTypeId.VOIDRAY) and self.supply_left > 0:
                await self.do(stargate.train(UnitTypeId.VOIDRAY))

    def find_target(self, state):
        """Pick what the army should attack.

        Preference order: a random known enemy unit, then a random known
        enemy structure, then the first enemy start location. ``state`` is
        accepted for caller compatibility and is not used.
        """
        if len(self.known_enemy_units) > 0:
            return random.choice(self.known_enemy_units)
        # Fixed: this branch used to re-test known_enemy_units, which made
        # the structure fallback unreachable.
        if len(self.known_enemy_structures) > 0:
            return random.choice(self.known_enemy_structures)
        return self.enemy_start_locations[0]

    async def attack(self):
        """Send idle void rays to attack or defend depending on their number."""
        # unit type -> [count needed to attack, count needed to defend]
        aggressive_units = {UnitTypeId.VOIDRAY: [8, 3]}

        for unit_type, (fight_at, defend_at) in aggressive_units.items():
            count = self.units(unit_type).amount
            if count > fight_at and count > defend_at:
                # Attack mode: go after whatever find_target selects.
                for unit in self.units(unit_type).idle:
                    await self.do(unit.attack(self.find_target(self.state)))
            elif count > defend_at:
                # Defence mode: only react to enemies that are currently known.
                if len(self.known_enemy_units) > 0:
                    for unit in self.units(unit_type).idle:
                        await self.do(unit.attack(random.choice(self.known_enemy_units)))


# Start the game: Protoss bot vs. a hard Terran computer, not in real time.
run_game(
    maps.get("AcidPlantLE"),
    [
        Bot(Race.Protoss, SentdeBot()),
        Computer(Race.Terran, Difficulty.Hard),
    ],
    realtime=False,
)

运行结果如下

采集到了地图位置。

侦察

在intel(self)里创建一个字典draw_dict,UnitTypeId作为key,半径和颜色是value


        # Drawing table: unit type -> [circle radius, colour tuple passed to cv2.circle].
        draw_dict = {
            UnitTypeId.NEXUS: [15, (0, 255, 0)],
            UnitTypeId.PYLON: [3, (20, 235, 0)],
            UnitTypeId.PROBE: [1, (55, 200, 0)],
            UnitTypeId.ASSIMILATOR: [2, (55, 200, 0)],
            UnitTypeId.GATEWAY: [3, (200, 100, 0)],
            UnitTypeId.CYBERNETICSCORE: [3, (150, 150, 0)],
            UnitTypeId.STARGATE: [5, (255, 0, 0)],
            UnitTypeId.ROBOTICSFACILITY: [5, (215, 155, 0)],

            UnitTypeId.VOIDRAY: [3, (255, 100, 0)]
        }

迭代同上

# Draw every finished unit of each listed type as a filled circle, using the
# radius and colour stored for that type in draw_dict.
for unit_type in draw_dict:
            for unit in self.units(unit_type).ready:
                pos = unit.position
                cv2.circle(game_data, (int(pos[0]), int(pos[1])), draw_dict[unit_type][0], draw_dict[unit_type][1], -1)

存储三族的主基地名称(星灵枢纽,指挥中心,孵化场),刻画敌方建筑。

# 主基地名称
        main_base_names = ["nexus", "supplydepot", "hatchery"]
        # 记录敌方基地位置
        for enemy_building in self.known_enemy_structures:
            pos = enemy_building.position
            if enemy_building.name.lower() not in main_base_names:
                cv2.circle(game_data, (int(pos[0]), int(pos[1])), 5, (200, 50, 212), -1)
        for enemy_building in self.known_enemy_structures:
            pos = enemy_building.position
            if enemy_building.name.lower() in main_base_names:
                cv2.circle(game_data, (int(pos[0]), int(pos[1])), 15, (0, 0, 255), -1)

刻画敌方单位,如果是农民画得小些,其他单位则画大些。

        # Draw every non-structure enemy unit: workers with radius 1, all
        # other units with radius 3.
        for enemy_unit in self.known_enemy_units:

            if not enemy_unit.is_structure:
                worker_names = ["probe", "scv", "drone"]
                # if that unit is a PROBE, SCV, or DRONE... it's a worker
                pos = enemy_unit.position
                if enemy_unit.name.lower() in worker_names:
                    cv2.circle(game_data, (int(pos[0]), int(pos[1])), 1, (55, 0, 155), -1)
                else:
                    cv2.circle(game_data, (int(pos[0]), int(pos[1])), 3, (50, 0, 215), -1)

在offensive_force_buildings(self)方法中添加建造机械台

            # With a finished cybernetics core, build one robotics facility
            # near the pylon if affordable and none is already pending.
            if self.units(CYBERNETICSCORE).ready.exists:
                if len(self.units(ROBOTICSFACILITY)) < 1:
                    if self.can_afford(ROBOTICSFACILITY) and not self.already_pending(ROBOTICSFACILITY):
                        await self.build(ROBOTICSFACILITY, near=pylon)

创建scout(),训练Observer

async def scout(self):
        """Move the first observer near the enemy start, or train one if none exists."""
        observers = self.units(OBSERVER)
        if len(observers) == 0:
            # No observer yet: train one at each queue-free finished
            # robotics facility when affordable and supply allows.
            for facility in self.units(ROBOTICSFACILITY).ready.noqueue:
                if self.can_afford(OBSERVER) and self.supply_left > 0:
                    await self.do(facility.train(OBSERVER))
            return

        observer = observers[0]
        if observer.is_idle:
            # Send the idle observer to a randomised point around the
            # first enemy start location.
            enemy_base = self.enemy_start_locations[0]
            destination = self.random_location_variance(enemy_base)
            print(destination)
            await self.do(observer.move(destination))

生成随机位置,很简单。意思是横坐标累计递增-0.2和0.2倍的横坐标,限制条件为如果x超过横坐标,那么就是横坐标最大值。
纵坐标同理。

    def random_location_variance(self, enemy_start_location):
        """Return a Point2 near ``enemy_start_location``, jittered and clamped to the map.

        Each coordinate is shifted by a random whole percentage in
        [-20, 19] of its own value, then limited to the range from 0 to
        the map size on that axis.
        """
        base_x = enemy_start_location[0]
        base_y = enemy_start_location[1]

        # The x offset is drawn before the y offset, as in the original.
        jittered_x = base_x + (random.randrange(-20, 20) / 100) * base_x
        jittered_y = base_y + (random.randrange(-20, 20) / 100) * base_y

        map_width = self.game_info.map_size[0]
        map_height = self.game_info.map_size[1]
        clamped_x = min(max(jittered_x, 0), map_width)
        clamped_y = min(max(jittered_y, 0), map_height)

        return position.Point2(position.Pointlike((clamped_x, clamped_y)))

完整代码如下

# -*- encoding: utf-8 -*-
'''
@File    :   demo.py
@Modify Time      @Author       @Desciption
------------      -------       -----------
2019/11/3 12:32   Jonas           None
'''

import sc2
from sc2 import run_game, maps, Race, Difficulty, position
from sc2.player import Bot, Computer
from sc2.constants import *
import random
import numpy as np
import cv2


class SentdeBot(sc2.BotAI):
    def __init__(self):
        """Set the tuning constants used by the economy and expansion logic."""
        # Chain to the base class so any state it initialises is not skipped.
        super().__init__()
        # By calculation, on_step runs roughly 165 times per minute.
        self.ITERATIONS_PER_MINUTE = 165
        # Maximum number of workers (Probes) to train.
        self.MAX_WORKERS = 50

    async def on_step(self, iteration: int):
        """Run one bot step: scouting, economy, production, drawing, attack.

        Args:
            iteration: Step counter passed in by the caller; stored on
                ``self.iteration`` because ``expand`` and
                ``offensive_force_buildings`` scale their targets with it.
        """
        self.iteration = iteration
        await self.scout()
        await self.distribute_workers()
        await self.build_workers()
        await self.build_pylons()
        await self.build_assimilators()
        await self.expand()
        await self.offensive_force_buildings()
        await self.build_offensive_force()
        await self.intel()
        await self.attack()

    ## 侦察
    async def scout(self):
        """Send the first Observer scouting, or train one when none exists.

        An idle Observer is ordered to a randomised point around the first
        enemy start location. When there is no Observer, one is requested from
        every Robotics Facility returned by ``.ready.noqueue`` while it is
        affordable and supply remains.
        """
        observers = self.units(UnitTypeId.OBSERVER)
        if len(observers) == 0:
            for robo in self.units(UnitTypeId.ROBOTICSFACILITY).ready.noqueue:
                if not self.can_afford(UnitTypeId.OBSERVER):
                    continue
                if self.supply_left <= 0:
                    continue
                await self.do(robo.train(UnitTypeId.OBSERVER))
            return

        observer = observers[0]
        if not observer.is_idle:
            return
        destination = self.random_location_variance(self.enemy_start_locations[0])
        print(destination)
        await self.do(observer.move(destination))

    async def intel(self):
        """Draw own and known enemy units on a map-sized image and display it.

        Every unit becomes a filled circle whose radius and colour depend on
        its type. The image is flipped with ``cv2.flip(..., 0)``, scaled 2x and
        shown in the ``'Intel'`` window.
        """
        map_w = self.game_info.map_size[0]
        map_h = self.game_info.map_size[1]
        canvas = np.zeros((map_h, map_w, 3), np.uint8)

        def dot(unit, radius, colour):
            # Filled circle (thickness -1) at the unit's truncated map position.
            where = unit.position
            cv2.circle(canvas, (int(where[0]), int(where[1])), radius, colour, -1)

        # Own units: UnitTypeId -> (radius, colour).
        own_markers = {
            UnitTypeId.NEXUS: (15, (0, 255, 0)),
            UnitTypeId.PYLON: (3, (20, 235, 0)),
            UnitTypeId.PROBE: (1, (55, 200, 0)),
            UnitTypeId.ASSIMILATOR: (2, (55, 200, 0)),
            UnitTypeId.GATEWAY: (3, (200, 100, 0)),
            UnitTypeId.CYBERNETICSCORE: (3, (150, 150, 0)),
            UnitTypeId.STARGATE: (5, (255, 0, 0)),
            UnitTypeId.ROBOTICSFACILITY: (5, (215, 155, 0)),
            UnitTypeId.VOIDRAY: (3, (255, 100, 0)),
        }
        for unit_type, (radius, colour) in own_markers.items():
            for unit in self.units(unit_type).ready:
                dot(unit, radius, colour)

        # Names treated as "main base" buildings.
        main_base_names = ["nexus", "supplydepot", "hatchery"]
        # Two passes: the large main-base markers are painted last, on top.
        for building in self.known_enemy_structures:
            if building.name.lower() not in main_base_names:
                dot(building, 5, (200, 50, 212))
        for building in self.known_enemy_structures:
            if building.name.lower() in main_base_names:
                dot(building, 15, (0, 0, 255))

        # Enemy non-structures: workers get a smaller marker than other units.
        worker_names = ["probe", "scv", "drone"]
        for enemy in self.known_enemy_units:
            if enemy.is_structure:
                continue
            if enemy.name.lower() in worker_names:
                dot(enemy, 1, (55, 0, 155))
            else:
                dot(enemy, 3, (50, 0, 215))

        for observer in self.units(UnitTypeId.OBSERVER).ready:
            dot(observer, 1, (255, 255, 255))

        # Flip code 0 flips around the x-axis (top row becomes bottom row).
        flipped = cv2.flip(canvas, 0)
        resized = cv2.resize(flipped, dsize=None, fx=2, fy=2)

        cv2.imshow('Intel', resized)
        cv2.waitKey(1)

    def random_location_variance(self, enemy_start_location):
        """Return a randomly perturbed point near ``enemy_start_location``.

        Each coordinate is shifted by ``random.randrange(-20, 20) / 100`` times
        its own value (so -20% up to +19%) and then clamped to ``[0, map_size]``
        on its axis.
        """
        base_x, base_y = enemy_start_location[0], enemy_start_location[1]
        # Draw the x offset before the y offset so the RNG sequence is unchanged.
        shift_x = (random.randrange(-20, 20) / 100) * base_x
        shift_y = (random.randrange(-20, 20) / 100) * base_y

        map_w = self.game_info.map_size[0]
        map_h = self.game_info.map_size[1]
        x = min(max(base_x + shift_x, 0), map_w)
        y = min(max(base_y + shift_y, 0), map_h)

        return position.Point2(position.Pointlike((x, y)))

    # 建造农民
    async def build_workers(self):
        """Train Probes until there are 16 per Nexus or ``MAX_WORKERS`` overall.

        Training is only requested from Nexuses returned by ``.ready.noqueue``,
        and only while a Probe is affordable.
        """
        probe_count = len(self.units(UnitTypeId.PROBE))
        wanted = len(self.units(UnitTypeId.NEXUS)) * 16
        if probe_count >= wanted or probe_count >= self.MAX_WORKERS:
            return
        for nexus in self.units(UnitTypeId.NEXUS).ready.noqueue:
            if self.can_afford(UnitTypeId.PROBE):
                await self.do(nexus.train(UnitTypeId.PROBE))

    ## 建造水晶
    async def build_pylons(self):
        """Build a Pylon near the first ready Nexus when supply is running low.

        Triggers when fewer than 5 supply remain and no Pylon is already
        pending, provided a ready Nexus exists and a Pylon is affordable.
        """
        if self.supply_left >= 5 or self.already_pending(UnitTypeId.PYLON):
            return
        ready_nexuses = self.units(UnitTypeId.NEXUS).ready
        if ready_nexuses.exists and self.can_afford(UnitTypeId.PYLON):
            await self.build(UnitTypeId.PYLON, near=ready_nexuses.first)

    ## 建造吸收厂
    async def build_assimilators(self):
        """Build Assimilators on vespene geysers within 15 units of ready Nexuses.

        For each Nexus the geyser scan stops as soon as an Assimilator is not
        affordable or no build worker can be selected. Geysers that already
        have an Assimilator within distance 1 are skipped.
        """
        for nexus in self.units(UnitTypeId.NEXUS).ready:
            nearby_geysers = self.state.vespene_geyser.closer_than(15.0, nexus)
            for geyser in nearby_geysers:
                if not self.can_afford(UnitTypeId.ASSIMILATOR):
                    break
                builder = self.select_build_worker(geyser.position)
                if builder is None:
                    break
                occupied = self.units(UnitTypeId.ASSIMILATOR).closer_than(1.0, geyser).exists
                if occupied:
                    continue
                await self.do(builder.build(UnitTypeId.ASSIMILATOR, geyser))

    ## 开矿
    async def expand(self):
        """Expand while the Nexus count is below the elapsed-minutes estimate.

        ``iteration / ITERATIONS_PER_MINUTE`` grows slowly over the game, so
        the number of bases the bot wants grows with it.
        """
        minutes_elapsed = self.iteration / self.ITERATIONS_PER_MINUTE
        if self.units(UnitTypeId.NEXUS).amount >= minutes_elapsed:
            return
        if self.can_afford(UnitTypeId.NEXUS):
            await self.expand_now()

    ## 建造进攻性建筑
    async def offensive_force_buildings(self):
        """Build the tech and production structures next to a random ready Pylon.

        Order of concerns: a Cybernetics Core once a Gateway is ready and no
        core exists; otherwise a first Gateway; then, once a core is ready, one
        Robotics Facility and a Stargate count that scales with game time.
        """
        print(self.iteration / self.ITERATIONS_PER_MINUTE)
        ready_pylons = self.units(UnitTypeId.PYLON).ready
        if not ready_pylons.exists:
            return
        pylon = ready_pylons.random

        async def build_if_possible(structure):
            # Place `structure` near the pylon when affordable and not pending.
            if self.can_afford(structure) and not self.already_pending(structure):
                await self.build(structure, near=pylon)

        # Per the Protoss tech tree, the core requires a finished Gateway.
        if self.units(UnitTypeId.GATEWAY).ready.exists and not self.units(UnitTypeId.CYBERNETICSCORE):
            await build_if_possible(UnitTypeId.CYBERNETICSCORE)
        elif len(self.units(UnitTypeId.GATEWAY)) < 1:
            await build_if_possible(UnitTypeId.GATEWAY)

        # Both remaining structures need a finished Cybernetics Core.
        if self.units(UnitTypeId.CYBERNETICSCORE).ready.exists:
            if len(self.units(UnitTypeId.ROBOTICSFACILITY)) < 1:
                await build_if_possible(UnitTypeId.ROBOTICSFACILITY)
            # Stargate target: half the elapsed-minutes estimate.
            if len(self.units(UnitTypeId.STARGATE)) < ((self.iteration / self.ITERATIONS_PER_MINUTE) / 2):
                await build_if_possible(UnitTypeId.STARGATE)

    ## 造兵
    async def build_offensive_force(self):
        """Train a Void Ray at every Stargate returned by ``.ready.noqueue``.

        A Stargate is skipped when a Void Ray is not affordable or no supply
        is left. Only Stargates produce units here.
        """
        for stargate in self.units(UnitTypeId.STARGATE).ready.noqueue:
            if not self.can_afford(UnitTypeId.VOIDRAY) or self.supply_left <= 0:
                continue
            await self.do(stargate.train(UnitTypeId.VOIDRAY))

    ## 寻找目标
    def find_target(self, state):
        """Pick something to attack.

        Preference order: a random known enemy unit, then a random known enemy
        structure, then the first enemy start location.

        Args:
            state: Unused; kept so existing callers keep working.

        Returns:
            An element of ``known_enemy_units`` or ``known_enemy_structures``,
            or ``enemy_start_locations[0]`` when both are empty.
        """
        if len(self.known_enemy_units) > 0:
            # Random enemy unit.
            return random.choice(self.known_enemy_units)
        # The original repeated the units test here, which made this branch
        # unreachable; it must test the structures collection it samples from.
        if len(self.known_enemy_structures) > 0:
            # Random enemy structure.
            return random.choice(self.known_enemy_structures)
        # Nothing known yet: fall back to the enemy start location.
        return self.enemy_start_locations[0]

    ## 进攻
    async def attack(self):
        """Order idle combat units to attack or defend based on army size.

        With more units than the attack threshold, idle units attack whatever
        ``find_target`` returns. With more than the defend threshold only, idle
        units attack a random known enemy unit, if any is known.
        """
        # Unit type -> [count needed to attack, count needed to defend].
        thresholds = {UnitTypeId.VOIDRAY: [8, 3]}

        for unit_type, (attack_min, defend_min) in thresholds.items():
            count = self.units(unit_type).amount
            if count > attack_min and count > defend_min:
                # Attack mode.
                for fighter in self.units(unit_type).idle:
                    await self.do(fighter.attack(self.find_target(self.state)))
            elif count > defend_min and len(self.known_enemy_units) > 0:
                # Defend mode.
                for fighter in self.units(unit_type).idle:
                    await self.do(fighter.attack(random.choice(self.known_enemy_units)))


## Start the game: Protoss SentdeBot vs. a Hard Terran computer on AcidPlantLE,
## with realtime disabled.
run_game(maps.get("AcidPlantLE"), [
    Bot(Race.Protoss, SentdeBot()), Computer(Race.Terran, Difficulty.Hard)
], realtime=False)

运行结果如下,红色和粉红色是敌方单位。

创建训练数据

统计资源、人口和军队人口比,在intel方法添加如下代码

        # 追踪资源、人口和军队人口比
        line_max = 50
        mineral_ratio = self.minerals / 1500
        if mineral_ratio > 1.0:
            mineral_ratio = 1.0

        vespene_ratio = self.vespene / 1500
        if vespene_ratio > 1.0:
            vespene_ratio = 1.0

        population_ratio = self.supply_left / self.supply_cap
        if population_ratio > 1.0:
            population_ratio = 1.0

        plausible_supply = self.supply_cap / 200.0

        military_weight = len(self.units(UnitTypeId.VOIDRAY)) / (self.supply_cap - self.supply_left)
        if military_weight > 1.0:
            military_weight = 1.0

        # 农民/人口      worker/supply ratio
        cv2.line(game_data, (0, 19), (int(line_max * military_weight), 19), (250, 250, 200), 3)
        # 人口/200    plausible supply (supply/200.0)
        cv2.line(game_data, (0, 15), (int(line_max * plausible_supply), 15), (220, 200, 200), 3)
        # (人口-现有人口)/人口  population ratio (supply_left/supply)
        cv2.line(game_data, (0, 11), (int(line_max * population_ratio), 11), (150, 150, 150), 3)
        # 气体/1500   gas/1500
        cv2.line(game_data, (0, 7), (int(line_max * vespene_ratio), 7), (210, 200, 0), 3)
        # 晶体矿/1500  minerals minerals/1500
        cv2.line(game_data, (0, 3), (int(line_max * mineral_ratio), 3), (0, 255, 25), 3)

运行结果如下,左下角自上而下依次是“军队/人口”、“人口/200”、“(人口-现有人口)/人口”、“气体/1500”、“晶体矿/1500”

采集进攻行为数据,在attack方法中加入如下代码

        if len(self.units(UnitTypeId.VOIDRAY).idle) > 0:
            choice = random.randrange(0, 4)
            target = False
            if self.iteration > self.do_something_after:
                if choice == 0:
                    # 什么都不做
                    wait = random.randrange(20, 165)
                    self.do_something_after = self.iteration + wait

                elif choice == 1:
                    # 攻击离星灵枢纽最近的单位
                    if len(self.known_enemy_units) > 0:
                        target = self.known_enemy_units.closest_to(random.choice(self.units(UnitTypeId.NEXUS)))

                elif choice == 2:
                    # 攻击敌方建筑
                    if len(self.known_enemy_structures) > 0:
                        target = random.choice(self.known_enemy_structures)

                elif choice == 3:
                    # 攻击敌方出生位置
                    target = self.enemy_start_locations[0]

                if target:
                    for vr in self.units(UnitTypeId.VOIDRAY).idle:
                        await self.do(vr.attack(target))
                y = np.zeros(4)
                y[choice] = 1
                print(y)
                self.train_data.append([y, self.flipped])

输出如下结果

···
[1. 0. 0. 0.]
[0. 0. 1. 0.]
[0. 0. 0. 1.]
[0. 0. 1. 0.]
[1. 0. 0. 0.]
···

为了使用self.flipped = cv2.flip(game_data, 0),修改

        flipped = cv2.flip(game_data, 0)
        resized = cv2.resize(flipped, dsize=None, fx=2, fy=2)

        self.flipped = cv2.flip(game_data, 0)
        resized = cv2.resize(self.flipped, dsize=None, fx=2, fy=2)

init 方法添加do_something_after和train_data

    def __init__(self):
        """Set tuning constants and the state used to collect training data."""
        # Chain to the base class so any state it initialises is not skipped.
        super().__init__()
        self.ITERATIONS_PER_MINUTE = 165
        self.MAX_WORKERS = 50
        # Iteration number that must be passed before attack() acts again.
        self.do_something_after = 0
        # Collected samples; attack() appends [one-hot choice, image] pairs.
        self.train_data = []

采集攻击数据的时候不需要画图,我们在类前加HEADLESS = False,intel方法代码修改如下

        self.flipped = cv2.flip(game_data, 0)

        if not HEADLESS:
            resized = cv2.resize(self.flipped, dsize=None, fx=2, fy=2)
            cv2.imshow('Intel', resized)
            cv2.waitKey(1)

加入on_end方法,只存储胜利的数据,在和代码同级目录新建train_data文件夹

    def on_end(self, game_result):
        """Log the game result and save the collected samples after a victory.

        Args:
            game_result: Result of the finished game; samples are written only
                when it equals ``Result.Victory``.
        """
        import os

        print('--- on_end called ---')
        print(game_result)

        if game_result == Result.Victory:
            # Create the output folder on demand so a win is not lost to a
            # missing directory.
            os.makedirs("train_data", exist_ok=True)
            # Each sample is [label, image] with different shapes, so the array
            # is ragged; recent NumPy needs an explicit dtype=object for that.
            np.save("train_data/{}.npy".format(str(int(time.time()))), np.array(self.train_data, dtype=object))

完整代码如下

import os
import time

import sc2
from sc2 import run_game, maps, Race, Difficulty, position, Result
from sc2.player import Bot, Computer
from sc2.constants import *
import random
import numpy as np
import cv2

HEADLESS = True
# os.environ["SC2PATH"] = 'F:\StarCraft II'

class SentdeBot(sc2.BotAI):
    def __init__(self):
        """Set tuning constants and the state used to collect training data."""
        # Chain to the base class so any state it initialises is not skipped.
        super().__init__()
        # By calculation, on_step runs roughly 165 times per minute.
        self.ITERATIONS_PER_MINUTE = 165
        # Maximum number of workers (Probes) to train.
        self.MAX_WORKERS = 50
        # Iteration number that must be passed before attack() acts again.
        self.do_something_after = 0
        # Collected samples; attack() appends [one-hot choice, image] pairs.
        self.train_data = []

    def on_end(self, game_result):
        """Log the game result and save the collected samples after a victory.

        Args:
            game_result: Result of the finished game; samples are written only
                when it equals ``Result.Victory``.
        """
        print('--- on_end called ---')
        print(game_result)

        if game_result == Result.Victory:
            # Create the output folder on demand so a win is not lost to a
            # missing directory.
            os.makedirs("train_data", exist_ok=True)
            # Each sample is [label, image] with different shapes, so the array
            # is ragged; recent NumPy needs an explicit dtype=object for that.
            np.save("train_data/{}.npy".format(str(int(time.time()))), np.array(self.train_data, dtype=object))

    async def on_step(self, iteration: int):
        """Run one bot step: scouting, economy, production, drawing, attack.

        Args:
            iteration: Step counter passed in by the caller; stored on
                ``self.iteration`` because ``expand``,
                ``offensive_force_buildings`` and ``attack`` read it.
        """
        self.iteration = iteration
        await self.scout()
        await self.distribute_workers()
        await self.build_workers()
        await self.build_pylons()
        await self.build_assimilators()
        await self.expand()
        await self.offensive_force_buildings()
        await self.build_offensive_force()
        # intel() must run before attack(): attack() stores self.flipped,
        # which intel() refreshes.
        await self.intel()
        await self.attack()

    ## 侦察
    async def scout(self):
        """Send the first Observer scouting, or train one when none exists.

        An idle Observer is ordered to a randomised point around the first
        enemy start location. When there is no Observer, one is requested from
        every Robotics Facility returned by ``.ready.noqueue`` while it is
        affordable and supply remains.
        """
        observers = self.units(UnitTypeId.OBSERVER)
        if len(observers) == 0:
            for robo in self.units(UnitTypeId.ROBOTICSFACILITY).ready.noqueue:
                if not self.can_afford(UnitTypeId.OBSERVER):
                    continue
                if self.supply_left <= 0:
                    continue
                await self.do(robo.train(UnitTypeId.OBSERVER))
            return

        observer = observers[0]
        if not observer.is_idle:
            return
        destination = self.random_location_variance(self.enemy_start_locations[0])
        print(destination)
        await self.do(observer.move(destination))

    async def intel(self):
        """Render the known game state to an image and cache it for training.

        Own and known enemy units are drawn as filled circles on a map-sized
        image, followed by five gauge lines (military share, supply cap / 200,
        free supply share, vespene / 1500, minerals / 1500). The image, flipped
        with ``cv2.flip(..., 0)``, is stored on ``self.flipped`` (which
        ``attack`` saves with each sample) and is shown in a window only when
        ``HEADLESS`` is false.
        """
        map_w = self.game_info.map_size[0]
        map_h = self.game_info.map_size[1]
        game_data = np.zeros((map_h, map_w, 3), np.uint8)

        def dot(unit, radius, colour):
            # Filled circle (thickness -1) at the unit's truncated map position.
            where = unit.position
            cv2.circle(game_data, (int(where[0]), int(where[1])), radius, colour, -1)

        def capped(ratio):
            # Gauges saturate at 1.0, i.e. at line_max pixels.
            return 1.0 if ratio > 1.0 else ratio

        # Own units: UnitTypeId -> (radius, colour).
        own_markers = {
            UnitTypeId.NEXUS: (15, (0, 255, 0)),
            UnitTypeId.PYLON: (3, (20, 235, 0)),
            UnitTypeId.PROBE: (1, (55, 200, 0)),
            UnitTypeId.ASSIMILATOR: (2, (55, 200, 0)),
            UnitTypeId.GATEWAY: (3, (200, 100, 0)),
            UnitTypeId.CYBERNETICSCORE: (3, (150, 150, 0)),
            UnitTypeId.STARGATE: (5, (255, 0, 0)),
            UnitTypeId.ROBOTICSFACILITY: (5, (215, 155, 0)),
            UnitTypeId.VOIDRAY: (3, (255, 100, 0)),
        }
        for unit_type, (radius, colour) in own_markers.items():
            for unit in self.units(unit_type).ready:
                dot(unit, radius, colour)

        # Names treated as "main base" buildings.
        # NOTE(review): "supplydepot" is not a Terran town hall; "commandcenter"
        # was probably intended. Left unchanged so newly rendered images stay
        # comparable with already collected training data - confirm.
        main_base_names = ["nexus", "supplydepot", "hatchery"]
        # Two passes: the large main-base markers are painted last, on top.
        for building in self.known_enemy_structures:
            if building.name.lower() not in main_base_names:
                dot(building, 5, (200, 50, 212))
        for building in self.known_enemy_structures:
            if building.name.lower() in main_base_names:
                dot(building, 15, (0, 0, 255))

        # Enemy non-structures: workers get a smaller marker than other units.
        worker_names = ["probe", "scv", "drone"]
        for enemy in self.known_enemy_units:
            if enemy.is_structure:
                continue
            if enemy.name.lower() in worker_names:
                dot(enemy, 1, (55, 0, 155))
            else:
                dot(enemy, 3, (50, 0, 215))

        for observer in self.units(UnitTypeId.OBSERVER).ready:
            dot(observer, 1, (255, 255, 255))

        # Track resources, supply and the military share of used supply.
        line_max = 50
        mineral_ratio = capped(self.minerals / 1500)
        vespene_ratio = capped(self.vespene / 1500)

        supply_cap = self.supply_cap
        supply_used = supply_cap - self.supply_left
        # Both denominators can reach 0 (e.g. no supply providers or no units
        # left); draw an empty gauge instead of raising ZeroDivisionError.
        population_ratio = capped(self.supply_left / supply_cap) if supply_cap else 0.0
        plausible_supply = supply_cap / 200.0
        voidray_count = len(self.units(UnitTypeId.VOIDRAY))
        military_weight = capped(voidray_count / supply_used) if supply_used else 0.0

        # (row, value, colour), listed from the highest row index downwards.
        gauges = (
            (19, military_weight, (250, 250, 200)),   # Void Rays / used supply
            (15, plausible_supply, (220, 200, 200)),  # supply cap / 200
            (11, population_ratio, (150, 150, 150)),  # supply left / supply cap
            (7, vespene_ratio, (210, 200, 0)),        # vespene / 1500
            (3, mineral_ratio, (0, 255, 25)),         # minerals / 1500
        )
        for row, value, colour in gauges:
            cv2.line(game_data, (0, row), (int(line_max * value), row), colour, 3)

        # Flip code 0 flips around the x-axis (top row becomes bottom row).
        self.flipped = cv2.flip(game_data, 0)

        # Only open the preview window when NOT running headless; the original
        # condition was inverted.
        if not HEADLESS:
            resized = cv2.resize(self.flipped, dsize=None, fx=2, fy=2)

            cv2.imshow('Intel', resized)
            cv2.waitKey(1)

    def random_location_variance(self, enemy_start_location):
        """Return a randomly perturbed point near ``enemy_start_location``.

        Each coordinate is shifted by ``random.randrange(-20, 20) / 100`` times
        its own value (so -20% up to +19%) and then clamped to ``[0, map_size]``
        on its axis.
        """
        base_x, base_y = enemy_start_location[0], enemy_start_location[1]
        # Draw the x offset before the y offset so the RNG sequence is unchanged.
        shift_x = (random.randrange(-20, 20) / 100) * base_x
        shift_y = (random.randrange(-20, 20) / 100) * base_y

        map_w = self.game_info.map_size[0]
        map_h = self.game_info.map_size[1]
        x = min(max(base_x + shift_x, 0), map_w)
        y = min(max(base_y + shift_y, 0), map_h)

        return position.Point2(position.Pointlike((x, y)))

    # 建造农民
    async def build_workers(self):
        """Train Probes until there are 16 per Nexus or ``MAX_WORKERS`` overall.

        Training is only requested from Nexuses returned by ``.ready.noqueue``,
        and only while a Probe is affordable.
        """
        probe_count = len(self.units(UnitTypeId.PROBE))
        wanted = len(self.units(UnitTypeId.NEXUS)) * 16
        if probe_count >= wanted or probe_count >= self.MAX_WORKERS:
            return
        for nexus in self.units(UnitTypeId.NEXUS).ready.noqueue:
            if self.can_afford(UnitTypeId.PROBE):
                await self.do(nexus.train(UnitTypeId.PROBE))

    ## 建造水晶
    async def build_pylons(self):
        """Build a Pylon near the first ready Nexus when supply is running low.

        Triggers when fewer than 5 supply remain and no Pylon is already
        pending, provided a ready Nexus exists and a Pylon is affordable.
        """
        if self.supply_left >= 5 or self.already_pending(UnitTypeId.PYLON):
            return
        ready_nexuses = self.units(UnitTypeId.NEXUS).ready
        if ready_nexuses.exists and self.can_afford(UnitTypeId.PYLON):
            await self.build(UnitTypeId.PYLON, near=ready_nexuses.first)

    ## 建造吸收厂
    async def build_assimilators(self):
        """Build Assimilators on vespene geysers within 15 units of ready Nexuses.

        For each Nexus the geyser scan stops as soon as an Assimilator is not
        affordable or no build worker can be selected. Geysers that already
        have an Assimilator within distance 1 are skipped.
        """
        for nexus in self.units(UnitTypeId.NEXUS).ready:
            nearby_geysers = self.state.vespene_geyser.closer_than(15.0, nexus)
            for geyser in nearby_geysers:
                if not self.can_afford(UnitTypeId.ASSIMILATOR):
                    break
                builder = self.select_build_worker(geyser.position)
                if builder is None:
                    break
                occupied = self.units(UnitTypeId.ASSIMILATOR).closer_than(1.0, geyser).exists
                if occupied:
                    continue
                await self.do(builder.build(UnitTypeId.ASSIMILATOR, geyser))

    ## 开矿
    async def expand(self):
        """Expand while the Nexus count is below the elapsed-minutes estimate.

        ``iteration / ITERATIONS_PER_MINUTE`` grows slowly over the game, so
        the number of bases the bot wants grows with it.
        """
        minutes_elapsed = self.iteration / self.ITERATIONS_PER_MINUTE
        if self.units(UnitTypeId.NEXUS).amount >= minutes_elapsed:
            return
        if self.can_afford(UnitTypeId.NEXUS):
            await self.expand_now()

    ## 建造进攻性建筑
    async def offensive_force_buildings(self):
        """Build the tech and production structures next to a random ready Pylon.

        Order of concerns: a Cybernetics Core once a Gateway is ready and no
        core exists; otherwise a first Gateway; then, once a core is ready, one
        Robotics Facility and a Stargate count that scales with game time.
        """
        ready_pylons = self.units(UnitTypeId.PYLON).ready
        if not ready_pylons.exists:
            return
        pylon = ready_pylons.random

        async def build_if_possible(structure):
            # Place `structure` near the pylon when affordable and not pending.
            if self.can_afford(structure) and not self.already_pending(structure):
                await self.build(structure, near=pylon)

        # Per the Protoss tech tree, the core requires a finished Gateway.
        if self.units(UnitTypeId.GATEWAY).ready.exists and not self.units(UnitTypeId.CYBERNETICSCORE):
            await build_if_possible(UnitTypeId.CYBERNETICSCORE)
        elif len(self.units(UnitTypeId.GATEWAY)) < 1:
            await build_if_possible(UnitTypeId.GATEWAY)

        # Both remaining structures need a finished Cybernetics Core.
        if self.units(UnitTypeId.CYBERNETICSCORE).ready.exists:
            if len(self.units(UnitTypeId.ROBOTICSFACILITY)) < 1:
                await build_if_possible(UnitTypeId.ROBOTICSFACILITY)
            # Stargate target: half the elapsed-minutes estimate.
            if len(self.units(UnitTypeId.STARGATE)) < ((self.iteration / self.ITERATIONS_PER_MINUTE) / 2):
                await build_if_possible(UnitTypeId.STARGATE)

    ## 造兵
    async def build_offensive_force(self):
        """Train a Void Ray at every Stargate returned by ``.ready.noqueue``.

        A Stargate is skipped when a Void Ray is not affordable or no supply
        is left. Only Stargates produce units here.
        """
        for stargate in self.units(UnitTypeId.STARGATE).ready.noqueue:
            if not self.can_afford(UnitTypeId.VOIDRAY) or self.supply_left <= 0:
                continue
            await self.do(stargate.train(UnitTypeId.VOIDRAY))

    ## 寻找目标
    def find_target(self, state):
        """Pick something to attack.

        Preference order: a random known enemy unit, then a random known enemy
        structure, then the first enemy start location.

        Args:
            state: Unused; kept so existing callers keep working.

        Returns:
            An element of ``known_enemy_units`` or ``known_enemy_structures``,
            or ``enemy_start_locations[0]`` when both are empty.
        """
        if len(self.known_enemy_units) > 0:
            # Random enemy unit.
            return random.choice(self.known_enemy_units)
        # The original repeated the units test here, which made this branch
        # unreachable; it must test the structures collection it samples from.
        if len(self.known_enemy_structures) > 0:
            # Random enemy structure.
            return random.choice(self.known_enemy_structures)
        # Nothing known yet: fall back to the enemy start location.
        return self.enemy_start_locations[0]

    ## 进攻
    async def attack(self):
        """Take a random attack decision and record it as a training sample.

        Runs only while idle Void Rays exist and ``self.iteration`` has passed
        ``self.do_something_after``. One of four choices is drawn:

        0. do nothing and wait a random 20-164 iterations,
        1. attack the known enemy unit closest to a random own Nexus,
        2. attack a random known enemy structure,
        3. attack the first enemy start location.

        The one-hot choice and the current ``self.flipped`` image are appended
        to ``self.train_data`` whether or not a target was found.
        """
        idle_voidrays = self.units(UnitTypeId.VOIDRAY).idle
        if len(idle_voidrays) == 0:
            return

        choice = random.randrange(0, 4)
        if self.iteration <= self.do_something_after:
            return

        target = None
        if choice == 0:
            # Do nothing for a while.
            wait = random.randrange(20, 165)
            self.do_something_after = self.iteration + wait

        elif choice == 1:
            # Attack the enemy unit closest to one of our Nexuses. Without the
            # Nexus check, random.choice raises IndexError once every Nexus
            # has been destroyed.
            nexuses = self.units(UnitTypeId.NEXUS)
            if len(self.known_enemy_units) > 0 and nexuses.exists:
                target = self.known_enemy_units.closest_to(random.choice(nexuses))

        elif choice == 2:
            # Attack an enemy structure.
            if len(self.known_enemy_structures) > 0:
                target = random.choice(self.known_enemy_structures)

        elif choice == 3:
            # Attack the enemy start location.
            target = self.enemy_start_locations[0]

        if target:
            for vr in self.units(UnitTypeId.VOIDRAY).idle:
                await self.do(vr.attack(target))

        y = np.zeros(4)
        y[choice] = 1
        print(y)
        self.train_data.append([y, self.flipped])


## Start the game: Protoss SentdeBot vs. a Medium Terran computer on
## AcidPlantLE, with realtime disabled.
run_game(maps.get("AcidPlantLE"), [
    Bot(Race.Protoss, SentdeBot()), Computer(Race.Terran, Difficulty.Medium)
], realtime=False)

可以看到train_data文件夹下存储了胜利数据

创建神经网络模型

sequential model如下

选择Sequential模型,dense层,dropout和flatten,将数据flatten之后再进入dense层。最终,我们使用卷积神经网络,所以我们需要Conv2D和MaxPooling2D。我们想可视化输出模型训练结果,再加上TensorBoard。
存储数据用numpy,os操作IO,random用来进行随机操作。

import keras
from keras.models import Sequential
from keras.layers import Dense,Dropout,Flatten,Conv2D,MaxPooling2D
from keras.callbacks import TensorBoard
import numpy as np
import os
import random

首先,创建模型

model = Sequential()

名词解释:
Sequential:序贯模型,序贯模型是函数式模型的简略版,为最简单的线性、从头到尾的结构顺序,不分叉
Conv2D:图像的空域卷积
MaxPooling2D:空域信号施加到最大值池化
Dropout:用于防止过拟合
Flatten:输入"压平",把多维的输入一维化,常用在卷积层到全连接层的过渡
Dense:全连接层
接下来,创建hidden卷积层。

model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=(176, 200, 3),
                 activation='relu'))
model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))

model.add(Conv2D(64, (3, 3), padding='same',
                 activation='relu'))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))

model.add(Conv2D(128, (3, 3), padding='same',
                 activation='relu'))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))

创建全连接层

model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.5))

输出层

model.add(Dense(4, activation='softmax'))

为神经网络设置compile属性

learning_rate = 0.0001
# Use the public `Adam` class: the lowercase `adam` alias only exists in old
# Keras releases, while `Adam` is available in old and new ones alike.
opt = keras.optimizers.Adam(lr=learning_rate, decay=1e-6)

# Four-way one-hot labels -> categorical cross-entropy.
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

compile方法解释:
compile(self, optimizer, loss, metrics=None, sample_weight_mode=None)
optimizer: 字符串(预定义优化器名)或优化器对象,参考优化器
loss: 字符串(预定义损失函数名)或目标函数,参考损失函数
metrics: 列表,包含评估模型在训练和测试时的网络性能的指标,典型用法是metrics=['accuracy']
sample_weight_mode:如果你需要按时间步为样本赋权(2D权矩阵),将该值设为“temporal”。
默认为“None”,代表按样本赋权(1D权)。在下面fit函数的解释中有相关的参考内容。
kwargs: 使用TensorFlow作为后端请忽略该参数,若使用Theano作为后端,kwargs的值将会传递给 K.function

通过TensorBoard记录数据

tensorboard = TensorBoard(log_dir="logs/stage1")

完整代码如下:

# -*- encoding: utf-8 -*-
'''
@File    :   building-model.py    
@Modify Time      @Author       @Desciption
------------      -------       -----------
2019/11/26 22:37   Jonas           None
'''
 
import keras
from keras.models import Sequential
from keras.layers import Dense,Dropout,Flatten,Conv2D,MaxPooling2D
from keras.callbacks import TensorBoard
import numpy as np
import os
import random

# Three convolutional stages (32 -> 64 -> 128 filters), each made of two 3x3
# convolutions, 2x2 max-pooling and 20% dropout.
model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=(176, 200, 3),
                 activation='relu'))
model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))

model.add(Conv2D(64, (3, 3), padding='same',
                 activation='relu'))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))

model.add(Conv2D(128, (3, 3), padding='same',
                 activation='relu'))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))

# Fully connected head.
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.5))

# Output layer: one unit per attack choice.
model.add(Dense(4, activation='softmax'))

learning_rate = 0.0001
# Use the public `Adam` class: the lowercase `adam` alias only exists in old
# Keras releases, while `Adam` is available in old and new ones alike.
opt = keras.optimizers.Adam(lr=learning_rate, decay=1e-6)

model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

tensorboard = TensorBoard(log_dir="logs/stage1")

# The training loop reads `train_data_dir`; the original only defined the
# misspelt `tran_data_dir`, which is kept as an alias for any existing users.
train_data_dir = "train_data"
tran_data_dir = train_data_dir

训练神经网络模型

train_data现成的训练文件链接,当然,你也可以自己训练。
链接:https://pan.baidu.com/s/1X0eRXakbmqmL_It9QbLYfQ
提取码:trow
一次处理200个文件

for i in range(hm_epochs):
    print(i)
    current = 0
    increment = 200
    not_maximum = True

获取文件数组首元素,即choice

    while not_maximum:
        print("现在正在做 {}:{}".format(current,current+increment))
        # One bucket per attack choice, in the order of the one-hot label.
        no_attacks = []
        attack_closest_to_nexus = []
        attack_enemy_structures = []
        attack_enemy_start = []
        for file in all_files[current:current+increment]:
            full_path = os.path.join(train_data_dir,file)
            # The files hold pickled object arrays ([label, image] pairs), which
            # np.load refuses to read unless allow_pickle=True is passed.
            # SECURITY: unpickling can execute arbitrary code - only load files
            # you generated yourself or fully trust.
            data = np.load(full_path, allow_pickle=True)
            data = list(data)
            for d in data:
                # Index of the largest label entry, i.e. the recorded choice.
                choice = np.argmax(d[0])
                if choice == 0:
                    no_attacks.append(d)
                elif choice == 1:
                    attack_closest_to_nexus.append(d)
                elif choice == 2:
                    attack_enemy_structures.append(d)
                elif choice == 3:
                    attack_enemy_start.append(d)

统计四种进攻策略的数量

# Report how many samples each attack choice currently has.
def check_data():
    """Print the size of each choice bucket and return the sizes.

    Reads the module-level lists ``no_attacks``, ``attack_closest_to_nexus``,
    ``attack_enemy_structures`` and ``attack_enemy_start``.

    Returns:
        The four bucket lengths, in the order listed above.
    """
    buckets = (
        ("no_attacks", no_attacks),
        ("attack_closest_to_nexus", attack_closest_to_nexus),
        ("attack_enemy_structures", attack_enemy_structures),
        ("attack_enemy_start", attack_enemy_start),
    )

    lengths = []
    for label, samples in buckets:
        count = len(samples)
        print("Length of {} is: {}".format(label, count))
        lengths.append(count)

    print("Total data length now is:",sum(lengths))
    return lengths

将四种进攻策略的数据规范到同一长度

        # Class sizes before balancing; the smallest one caps every bucket.
        # (This assignment was missing, leaving `lengths` undefined.)
        lengths = check_data()
        lowest_data = min(lengths)
        # Shuffle first so the truncation below keeps a random subset.
        random.shuffle(no_attacks)
        random.shuffle(attack_closest_to_nexus)
        random.shuffle(attack_enemy_structures)
        random.shuffle(attack_enemy_start)
        # Shrink every bucket to the size of the smallest class.
        no_attacks = no_attacks[:lowest_data]
        attack_closest_to_nexus = attack_closest_to_nexus[:lowest_data]
        attack_enemy_structures = attack_enemy_structures[:lowest_data]
        attack_enemy_start = attack_enemy_start[:lowest_data]

        # Print the balanced sizes.
        check_data()

把四种进攻策略合并成一个大list

        train_data = no_attacks + attack_closest_to_nexus + attack_enemy_structures + attack_enemy_start
        # Mix the classes; without this the held-out tail would come from a
        # single class because the buckets are concatenated in order.
        random.shuffle(train_data)
        test_size = 100  # samples held out of every batch for validation
        batch_size = 128  # samples per gradient update

100个验证集,len(train_data)-100个训练集

        # Hold out the last test_size samples for validation; train on the rest.
        training_part = train_data[:-test_size]
        held_out_part = train_data[-test_size:]

        # Element 1 of each sample is the model input, reshaped to
        # (N, 176, 200, 3); element 0 is its label vector.
        x_train = np.array([sample[1] for sample in training_part]).reshape(-1, 176, 200, 3)
        x_test = np.array([sample[1] for sample in held_out_part]).reshape(-1, 176, 200, 3)

        y_train = np.array([sample[0] for sample in training_part])
        y_test = np.array([sample[0] for sample in held_out_part])

拟合数据
fit方法参数解释:

fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)
x: 训练数据的 Numpy 数组(如果模型只有一个输入), 或者是 Numpy 数组的列表(如果模型有多个输入)。 如果模型中的输入层被命名,你也可以传递一个字典,将输入层名称映射到 Numpy 数组。 如果从本地框架张量馈送(例如 TensorFlow 数据张量)数据,x 可以是 None(默认)。
y: 目标(标签)数据的 Numpy 数组(如果模型只有一个输出), 或者是 Numpy 数组的列表(如果模型有多个输出)。 如果模型中的输出层被命名,你也可以传递一个字典,将输出层名称映射到 Numpy 数组。 如果从本地框架张量馈送(例如 TensorFlow 数据张量)数据,y 可以是 None(默认)。
batch_size: 整数或 None。每次梯度更新的样本数。如果未指定,默认为 32。
verbose: 0, 1 或 2。日志显示模式。 0 = 安静模式, 1 = 进度条, 2 = 每轮一行。
validation_data: 元组 (x_val,y_val) 或元组 (x_val,y_val,val_sample_weights), 用来评估损失,以及在每轮结束时的任何模型度量指标。 模型将不会在这个数据上进行训练。这个参数会覆盖 validation_split。
shuffle: 布尔值(是否在每轮迭代之前混洗数据)或者 字符串 (batch)。 batch 是处理 HDF5 数据限制的特殊选项,它对一个 batch 内部的数据进行混洗。 当 steps_per_epoch 非 None 时,这个参数无效。
callbacks: 一系列的 keras.callbacks.Callback 实例。一系列可以在训练时使用的回调函数。

# Train on the current batch (epochs is left at its default) and validate on
# the held-out samples; the TensorBoard callback records the run.
model.fit(x_train,y_train,batch_size=batch_size,validation_data=(x_test,y_test),shuffle=True,verbose=1,
                  callbacks=[tensorboard])

保存模型

        # Checkpoint after every batch; the file name records the epoch count
        # and the learning rate.
        model.save("BasicCNN-{}-epochs-{}-LR-STAGE1".format(hm_epochs, learning_rate))
        current += increment
        # Stop once every file has been consumed. `>=` rather than `>` avoids
        # one extra pass over an empty slice when the number of files is an
        # exact multiple of `increment`.
        if current >= maximum:
            not_maximum = False

至此,完整代码如下

import keras
from keras.models import Sequential
from keras.layers import Dense,Dropout,Flatten,Conv2D,MaxPooling2D
from keras.callbacks import TensorBoard
import numpy as np
import os
import random

# Build a sequential (layer-by-layer) model.
model = Sequential()

# 32-filter stage; the first layer fixes the input shape to 176x200x3.
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=(176, 200, 3),
                 activation='relu'))
model.add(Conv2D(32, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))

# 64-filter stage: two 3x3 convolutions, 2x2 max pooling, dropout.
model.add(Conv2D(64, (3, 3), padding='same',
                 activation='relu'))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))

# 128-filter stage: same layout as the 64-filter stage.
model.add(Conv2D(128, (3, 3), padding='same',
                 activation='relu'))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))

# Classifier head: flatten, one 512-unit hidden layer with heavier dropout.
model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.5))

# Four softmax outputs, one per attack choice.
model.add(Dense(4, activation='softmax'))

learning_rate = 0.0001
# Use the canonical `Adam` class; the lowercase `adam` alias is not available
# in every Keras release.
opt = keras.optimizers.Adam(lr=learning_rate, decay=1e-6)

model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

# TensorBoard callback writing its logs under logs/stage1.
tensorboard = TensorBoard(log_dir="logs/stage1")

# Directory holding the recorded training files.
train_data_dir = "train_data"

# Report how many samples each attack choice currently has.
def check_data():
    """Print the size of every attack-choice bucket and return the sizes.

    Reads the four module-level lists ``no_attacks``,
    ``attack_closest_to_nexus``, ``attack_enemy_structures`` and
    ``attack_enemy_start``.

    Returns:
        A list with the four bucket lengths, in the order listed above.
    """
    choices = {"no_attacks": no_attacks,
               "attack_closest_to_nexus": attack_closest_to_nexus,
               "attack_enemy_structures": attack_enemy_structures,
               "attack_enemy_start": attack_enemy_start}

    lengths = [len(samples) for samples in choices.values()]

    for label, count in zip(choices, lengths):
        print("Length of {} is: {}".format(label, count))

    print("Total data length now is:", sum(lengths))
    return lengths

# Number of full passes over the training files.
hm_epochs = 10

# Loop-invariant settings, defined once instead of on every iteration.
increment = 200  # files loaded per batch
test_size = 100  # samples held out of every batch for validation
batch_size = 128  # samples per gradient update

for i in range(hm_epochs):
    print(i)
    # Re-list and reshuffle the files so every pass sees a different order.
    all_files = os.listdir(train_data_dir)
    maximum = len(all_files)
    random.shuffle(all_files)

    # Stepping with range() stops exactly at the last file. The earlier
    # `current > maximum` check ran one extra batch on an empty slice whenever
    # the file count was an exact multiple of `increment`.
    for current in range(0, maximum, increment):
        print("现在正在做 {}:{}".format(current,current+increment))
        # One bucket per attack choice, rebuilt from scratch for every batch.
        # They stay module-level names because check_data() reads them.
        no_attacks = []
        attack_closest_to_nexus = []
        attack_enemy_structures = []
        attack_enemy_start = []
        for file_name in all_files[current:current+increment]:
            full_path = os.path.join(train_data_dir, file_name)
            # Each record pairs a choice vector (d[0]) with model input (d[1]),
            # so the files are presumably pickled object arrays. NumPy 1.16.3
            # and later refuse to load those unless allow_pickle=True.
            # NOTE(review): unpickling can execute arbitrary code; only load
            # training files from a source you trust.
            data = np.load(full_path, allow_pickle=True)
            for d in data:
                # Index of the largest entry of the choice vector.
                choice = np.argmax(d[0])
                if choice == 0:
                    no_attacks.append(d)
                elif choice == 1:
                    attack_closest_to_nexus.append(d)
                elif choice == 2:
                    attack_enemy_structures.append(d)
                elif choice == 3:
                    attack_enemy_start.append(d)

        # Class sizes before balancing; the smallest one caps every bucket.
        lengths = check_data()

        lowest_data = min(lengths)
        # Shuffle first so the truncation below keeps a random subset.
        random.shuffle(no_attacks)
        random.shuffle(attack_closest_to_nexus)
        random.shuffle(attack_enemy_structures)
        random.shuffle(attack_enemy_start)
        # Shrink every bucket to the size of the smallest class.
        no_attacks = no_attacks[:lowest_data]
        attack_closest_to_nexus = attack_closest_to_nexus[:lowest_data]
        attack_enemy_structures = attack_enemy_structures[:lowest_data]
        attack_enemy_start = attack_enemy_start[:lowest_data]

        # Print the balanced sizes.
        check_data()

        train_data = no_attacks + attack_enemy_start + attack_enemy_structures + attack_closest_to_nexus
        random.shuffle(train_data)

        # Holding out test_size samples must leave something to train on;
        # otherwise the training slice below would be empty.
        if len(train_data) <= test_size:
            print("Skipping batch {}:{}: {} balanced samples, need more than {}".format(
                current, current + increment, len(train_data), test_size))
            continue

        training_part = train_data[:-test_size]
        held_out_part = train_data[-test_size:]

        # Reshape to four dimensions; -1 lets NumPy infer the sample count.
        x_train = np.array([sample[1] for sample in training_part]).reshape(-1, 176, 200, 3)
        y_train = np.array([sample[0] for sample in training_part])

        x_test = np.array([sample[1] for sample in held_out_part]).reshape(-1, 176, 200, 3)
        y_test = np.array([sample[0] for sample in held_out_part])

        model.fit(x_train,y_train,batch_size=batch_size,validation_data=(x_test,y_test),shuffle=True,verbose=1,
                  callbacks=[tensorboard])

        # Checkpoint after every batch; the file name records the epoch count
        # and the learning rate.
        model.save("BasicCNN-{}-epochs-{}-LR-STAGE1".format(hm_epochs, learning_rate))
posted @ 2019-11-21 23:28  Rest探路者  阅读(...)  评论(... 编辑 收藏
快乐十分开奖号码搜索