-- --8<--8<--8<--8<--
--
-- Copyright (C) 2011 Smithsonian Astrophysical Observatory
--
-- This file is part of chandra.saotrace.aperture
--
-- chandra.saotrace.aperture is free software: you can redistribute it
-- and/or modify it under the terms of the GNU General Public License
-- as published by the Free Software Foundation, either version 3 of
-- the License, or (at your option) any later version.
--
-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU General Public License for more details.
--
-- You should have received a copy of the GNU General Public License
-- along with this program.  If not, see <http://www.gnu.org/licenses/>.
--
-- -->8-->8-->8-->8--

local require = require
local setmetatable = setmetatable
local select = select
local assert = assert
local pairs = pairs
local error = error
local type = type


local string = require( 'string' )
local RDB = require( 'rdb' )
local paths   = require( 'saotrace.suplib.paths' )
local strings = require( 'saotrace.suplib.strings' )
local aperture = require( 'saotrace.aperture' )
local funcs    = require( 'chandra.saotrace.aperture.baffles.funcs' )


-- Declare this file as a Lua 5.1-style module.  module( ... ) makes the
-- module table the environment of the rest of this chunk; since
-- package.seeall is not used, globals are not visible below this point
-- unless they were captured as locals above (hence the local aliases of
-- require, setmetatable, pairs, etc.).
module( ... )

-- NOTE(review): presumably makes reads/writes of undeclared variables in
-- the module table an error -- confirm against the Strict module.
require( 'Strict' ).restrict(_M)


-- argument validator (validate.args) shared by the methods of this module.
-- NOTE(review): named = true presumably means methods receive their
-- arguments as a single table of named values -- TODO confirm.
local vobj = require( 'validate.args' ):new()
vobj:setopts{ named = true }


--------------------------------------------------------------------------
-- data types
--
--   displacement {
--     dx              -- (mm)           decenter
--     dy              -- (mm)           decenter
--     dz              -- (mm)           despace
--   }
--   disorientation {
--     azmis           -- (radians)
--     elmis           -- (radians)
--     clock           -- (radians)
--   }
--
--  annulus_info {
--      ri             -- (mm)       annulus inner radius
--      ro             -- (mm)       annulus outer radius
--      dri            -- (mm)       tolerance on inner radius
--      dro            -- (mm)       tolerance on outer radius
--      type           -- -1 = no annulus, 0 = transp, 1 = opaque
--  }
--
--  strut_info {
--      num            --            number of struts
--      width          -- (mm)       strut width
--      theta0         -- (radians)  offset of first strut
--      delta_theta    -- (radians)  angle between strut centerlines
--      R_i            -- (mm)       innermost radius for strut
--      R_o            -- (mm)       outermost radius for strut
--      type           -- strut | rect
--  }
--
--  baffle {
--      id             --            identifier
--      z              -- (mm)       axial position of baffle
--      annulus_info   -- table      annulus data
--      strut_info     -- table      strut data
--      displacement   -- table      position    tweaks
--      disorientation -- table      orientation tweaks
--  }


--------------------------------------------------------------------------
-- new
--
-- create a new baffle set.
--
-- Any arguments are forwarded unchanged to read_config(), which reads the
-- assembly and baffle descriptions and constructs the set.  With no
-- arguments an empty object is returned.

__index = nil

function new( self, ... )

   -- instances look up their methods in the class table passed as self
   self.__index = self

   local instance = setmetatable( {}, self )

   -- select('#', ...) counts trailing nils too, so an explicit nil
   -- argument still triggers read_config()
   if select( '#', ... ) ~= 0 then
      instance:read_config( ... )
   end

   return instance
end

--------------------------------------------------------------------------
-- read_config
--
-- read baffle assembly descriptions from an rdbfile;
-- construct a baffle assembly
--
-- Inputs: a table containing...
--
--   assembly_name  - name of the assembly
--   shell          - shell number to be constructed
--   config_db      - rdb file containing information on how to
--                    construct the assemblies
--   coord_type     - coordinate conversion type ('HRMA' or 'XRCF')
--                    used in converting the database values to
--                    raytrace coordinates. defaults to 'XRCF'
--   insert_struts  - insert struts?  defaults to true.
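--
-- Note: tol_scale (tolerance scale, see apply_annulus_tolerances()) is
-- not taken from the arguments; it is read from the assembly's record
-- in config_db.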

-- validation specification for the named arguments of read_config()
read_config_vspec = {

   assembly_name = { required = true, type = 'string' },
   config_db     = { required = true, type = 'string' },
   shell         = { required = true, type = 'posnum' },

   -- the value is upper-cased by the precall hook (presumably before the
   -- enum check -- confirm against validate.args), so 'xrcf' is accepted
   coord_type    = {
      enum    = { 'HRMA', 'XRCF' },
      default = 'XRCF',
      precall = function( value )
         if not value then
            return
         end
         return true, string.upper( value )
      end,
   },

   insert_struts = { default = true, type = 'boolean' },
}


function read_config( self, ... )

   -- Validate the named arguments against read_config_vspec and copy
   -- them (assembly_name, config_db, shell, coord_type, insert_struts)
   -- into the object.
   local ok, args = vobj:validate( read_config_vspec, ... )

   assert( ok, args )

   for k, v in pairs( args ) do
      self[k] = v
   end

   -- XRCF coordinates are rotated by 180 degrees about the mechanical
   -- axis relative to HRMA coordinates; applied in construct_baffles()
   self.coord_rot = 0.0
   if self.coord_type == 'XRCF' then
      self.coord_rot = 180.0 * funcs.deg2rad
   end

   -- scan config_db for the record whose id is this assembly's name
   local rdb  = RDB( args.config_db )
   local data = rdb:read( )

   while data and data.id ~= self.assembly_name do
      data = rdb:read( )
   end

   if not data then
      error( string.format( "couldn't find assembly %s in %s\n",
                            self.assembly_name, args.config_db ) )
   end

   self.tol_scale    = data.tol_scale
   self.assembly_arg = data.assembly_arg

   self.displacement =
      funcs.construct_displacement( data.delta_X,
                                    data.delta_Y,
                                    data.delta_Z )

   self.disorientation =
      funcs.construct_disorientation( data.azmis_H,
                                      data.elmis_H,
                                      data.clock_H )

   -- interpolate any possible environmental variables
   self.baffle_cfg = strings.interp( data.baffle_cfg )

   -- A relative baffle_cfg is resolved against the directory holding
   -- config_db.  When config_db has no directory component (assumes
   -- paths.split then returns nil or '' -- TODO confirm), the path is
   -- left as is instead of being turned into an absolute path under '/'.
   if not self.baffle_cfg:find( "^/" ) then

      local dirname = paths.split( self.config_db )

      if dirname and dirname ~= '' then
         self.baffle_cfg = dirname .. "/" .. self.baffle_cfg
      end
   end

   -- read in baffle information; construct baffle set
   self:read_baffle_info()

   return self

end

--------------------------------------------------------------------------
-- read_baffle_info
--
-- read the descriptions of the baffles belonging to shell self.shell
-- from the rdb file self.baffle_cfg and construct the baffle set.
--
-- Requires self.assembly_name, self.baffle_cfg, self.shell and
-- self.insert_struts (as set by read_config()).
--
-- On return:
--
--   self.baffle        -- baffle descriptions, indexed 1..n in file order
--   self.first_baffle  -- 1
--   self.last_baffle   -- n
--   self.z_center      -- axial station of the body center, midway between
--                         the first and last baffles
--
-- Each baffle's z is stored relative to z_center.  Annulus tolerances are
-- applied (apply_annulus_tolerances()) before z is made relative.
--
-- Raises an error if self.baffle_cfg contains no baffles for self.shell.

function read_baffle_info( self )

   local errpfx = 'BaffleSet:read_baffle_info: ' .. self.assembly_name .. ': '

   local rdb = RDB( self.baffle_cfg )

   -- Record: z - axial coordinate, ri - inner radius, ro - outer radius
   -- Dimensions are in mm; in raytrace coordinates.

   self.first_baffle = 1
   self.last_baffle  = 0
   self.baffle       = {}

   local pid  = 0
   local data = rdb:read( )

   while data do

      -- keep only the records describing the requested shell
      if self.shell == data.shell then

         pid = pid + 1

         local baffle = {}

         baffle.id = data.id
         -- axial position is the negated database X coordinate
         baffle.z  = -data.X

         baffle.displacement =
            funcs.construct_displacement( data.delta_X,
                                          data.delta_Y,
                                          data.delta_Z )
         baffle.disorientation =
            funcs.construct_disorientation( data.azmis_H,
                                            data.elmis_H,
                                            data.clock_H )

         baffle.annulus_info = funcs.construct_annulus( data )

         -- no struts when they are disabled or none are described
         if not self.insert_struts or nil == data.strut_type then
            data.strut_type = 'none'
         end

         baffle.strut_info = funcs.construct_struts( data )

         self.baffle[pid] = baffle

      end

      data = rdb:read( )

   end

   self.last_baffle = pid

   if 0 == pid then
      -- errpfx already ends with ': ', so no extra separator here
      error( string.format( "%scouldn't find shell %d in %s\n",
                            errpfx, self.shell, self.baffle_cfg ) )
   end

   self:apply_annulus_tolerances()

   -----------------------------------------
   -- body center z in raytrace coordinates:

   self.z_center = ( self.baffle[self.first_baffle].z
                     + self.baffle[self.last_baffle].z ) / 2.0

   for i = self.first_baffle, self.last_baffle do
      self.baffle[i].z = self.baffle[i].z - self.z_center
   end

   return self
end


--------------------------------------------------------------------------
-- apply_annulus_tolerances
--
-- apply tolerances to the annulus radii of every baffle in the set,
-- scaled by self.tol_scale:
--
--    tol_scale = -1   -- minimum ghosting, maximum vignetting
--                 0   -- nominal baffles; zero tolerances applied
--                 1   -- maximum ghosting, minimum vignetting
--
-- Nonintegral values for tol_scale are permitted.  A missing (nil)
-- tol_scale is treated as 0, i.e. nominal baffles.
--
-- The scaled tolerances are stored back into dri/dro; the inner radius
-- is then reduced by dri and the outer radius increased by dro.
-- Baffles whose annulus_info is not a table, or whose annulus type is
-- -1 (no annulus), are left untouched.
--
-- NB: modifies the baffle descriptions in place; calling this more than
-- once compounds the tolerances.

function apply_annulus_tolerances( self )

   local scale = self.tol_scale or 0

   for pid = self.first_baffle, self.last_baffle do

      local info = self.baffle[pid].annulus_info

      -- only a table carries radii to adjust (non-table values, presumably
      -- a boolean flag, describe no annulus); type -1 means no annulus
      if type( info ) == 'table' and info.type ~= -1 then

         info.dri = info.dri * scale
         info.dro = info.dro * scale

         info.ri = info.ri - info.dri
         info.ro = info.ro + info.dro

      end

   end

end

--------------------------------------------------------------------------
-- construct_baffles
--
-- build the baffle set through the aperture interface.
--
-- The set as a whole is rotated about z by self.coord_rot, translated to
-- its body-center axial station self.z_center, and given the set-wide
-- displacement/disorientation.  Then, for each baffle in order, its own
-- displacement/disorientation is applied and the baffle is built
-- (funcs.do_baffle) between aperture.begin_assembly() and
-- aperture.end_assembly().
--
-- Inputs:
--
--  debug_baffle -- id of baffle to debug, or 'all' for whole set;
--                  passed through to funcs.do_baffle()

function construct_baffles( self, debug_baffle )

  -- rotate according to the raytrace system in use.  the default
  -- is to map hrma to raytrace; xrcf to raytrace is rotated about
  -- the system mechanical axis by 180 degrees.
  aperture.rotate_z( self.coord_rot )

  -- the baffle set is specified in body centered coordinates;
  -- move it to the correct position in raytrace coordinates.
  -- z_center is the position of the body center in raytrace coordinates.
  aperture.translate( 0, 0, self.z_center )

  -- adjust position of the assembly as a whole
  funcs.apply_displacement_disorientation( self.displacement,
                                           self.disorientation )

  -- print_assembly()           -- for tests only

  for id = self.first_baffle, self.last_baffle do

    local baffle = self.baffle[id]

    -- tweak the baffle
    funcs.apply_displacement_disorientation( baffle.displacement,
                                             baffle.disorientation )

    aperture.begin_assembly( self.assembly_arg )

      -- print_photon( 'raw_project' )

    funcs.do_baffle( baffle, debug_baffle )

      -- if ( baffle.id == debug_baffle or debug_baffle == "all" ) then
      --   print_assembly()
      -- end

    aperture.end_assembly()

  end

end

--------------------------------------------------------------------------
-- get_baffleset_body_center_axial_station
--
-- axial station (raytrace coordinates) of the body center of the baffle
-- set as a whole.

function get_baffleset_body_center_axial_station( self )
  local station = self.z_center
  return station
end

--------------------------------------------------------------------------
-- get_baffleset_displacement
--
-- displacement applied to the baffle set as a whole.  The stored table is
-- returned itself, not a copy.

function get_baffleset_displacement( self )
  local set_displacement = self.displacement
  return set_displacement
end

--------------------------------------------------------------------------
-- get_baffleset_disorientation
--
-- disorientation applied to the baffle set as a whole.  The stored table
-- is returned itself, not a copy.

function get_baffleset_disorientation( self )
  local set_disorientation = self["disorientation"]
  return set_disorientation
end

--------------------------------------------------------------------------
-- get_first_baffle_number
--
-- index (within the baffle set) of the first baffle.

function get_first_baffle_number( self )
  local first = self.first_baffle
  return first
end

--------------------------------------------------------------------------
-- get_last_baffle_number
--
-- index (within the baffle set) of the last baffle.

function get_last_baffle_number( self )
  local last = self.last_baffle
  return last
end

--------------------------------------------------------------------------
-- get_baffle_id
--
-- identifier (the `id` field from the baffle configuration) of the
-- baffle at index pid.
--
-- Inputs:
--
--   pid -- baffle index number

function get_baffle_id( self, pid )
  local entry = self.baffle[pid]
  return entry.id
end

--------------------------------------------------------------------------
-- get_baffle_net_displacement
--
-- return the net displacement of baffle pid: the baffle's own
-- displacement plus the displacement of the baffle set as a whole.
--
-- A new table is returned; neither the baffle's stored displacement nor
-- the set's displacement is modified, so repeated calls give the same
-- answer and construct_baffles() still sees the original values.
--
-- Inputs:
--
--   pid -- baffle index number

function get_baffle_net_displacement( self, pid )

  local own = self.baffle[pid].displacement
  local set = self.displacement

  -- start from a copy so any additional fields of the baffle's
  -- displacement are preserved without touching the stored table
  local net = {}
  for k, v in pairs( own ) do
    net[k] = v
  end

  net.dx = own.dx + set.dx
  net.dy = own.dy + set.dy
  net.dz = own.dz + set.dz

  return net

end

--------------------------------------------------------------------------
-- get_baffle_displacement
--
-- displacement of baffle pid alone (excluding the displacement of the
-- set as a whole).  The stored table is returned itself.
--
-- Inputs:
--
--   pid -- baffle index number

function get_baffle_displacement( self, pid )
  local entry = self.baffle[pid]
  return entry.displacement
end

--------------------------------------------------------------------------
-- get_baffle_disorientation
--
-- disorientation of baffle pid alone (excluding the disorientation of the
-- set as a whole).  The stored table is returned itself.
--
-- Inputs:
--
--   pid -- baffle index number

function get_baffle_disorientation( self, pid )
  local entry = self.baffle[pid]
  return entry.disorientation
end

--------------------------------------------------------------------------
-- get_baffle_axial_station
--
-- axial station of baffle pid in raytrace coordinates: its stored
-- body-relative z plus the body-center station of the set.
-- NB:  does not include the effect of any baffle displacements or
--      disorientation
--
-- Inputs:
--
--   pid -- baffle index number

function get_baffle_axial_station( self, pid )
  local relative_z = self.baffle[pid].z
  local station = relative_z + self.z_center
  return station
end


--------------------------------------------------------------------------
-- get_annulus_info
--
-- annulus description of baffle pid (the stored annulus_info value,
-- returned as is).
--
-- Inputs:
--
--   pid -- baffle index number

function get_annulus_info( self, pid )
  local entry = self.baffle[pid]
  return entry.annulus_info
end



--------------------------------------------------------------------------
-- get_strut_info
--
-- strut description of baffle pid (the stored strut_info value,
-- returned as is).
--
-- Inputs:
--
--   pid -- baffle index number

function get_strut_info( self, pid )
  local entry = self.baffle[pid]
  return entry.strut_info
end
